mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
38 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ba25b02978 | |||
| 966da7b7c8 | |||
| f0af90bff3 | |||
| 1638968509 | |||
| 250fcbe62c | |||
| 0595d2336a | |||
| 387c8f1df3 | |||
| f6c2b287bc | |||
| ab188d026d | |||
| 316fc260f6 | |||
| aab7fa483b | |||
| 496364ce53 | |||
| 4446d3180f | |||
| 7b8241e424 | |||
| 8857f02083 | |||
| 1747b3d6b2 | |||
| 36623a16da | |||
| 9d3b77bffc | |||
| 2682518d0e | |||
| b27fabe504 | |||
| adf7ab5849 | |||
| 456998175f | |||
| b4afd9f170 | |||
| 73c7375b92 | |||
| 6414b1af6e | |||
| dd55290f4e | |||
| be77baea31 | |||
| a812e2b5f1 | |||
| 4ebff5aaf3 | |||
| 0687608feb | |||
| db4e1dbbec | |||
| 9442e4f9e3 | |||
| e17f7b22a6 | |||
| ce6939fc0d | |||
| 4705ef9ec2 | |||
| 9c2b48ff5d | |||
| 87906b96a7 | |||
| c0a0d46eb2 |
@@ -399,3 +399,49 @@ jobs:
|
||||
run: |
|
||||
echo "Some runtime tests failed or were cancelled"
|
||||
exit 1
|
||||
update_pr_description:
|
||||
name: Update PR Description
|
||||
if: github.event_name == 'pull_request' && !github.event.pull_request.head.repo.fork && github.actor != 'dependabot[bot]'
|
||||
needs: [ghcr_build_runtime]
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get short SHA
|
||||
id: short_sha
|
||||
run: echo "SHORT_SHA=$(echo ${{ github.event.pull_request.head.sha }} | cut -c1-7)" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Update PR Description
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
REPO: ${{ github.repository }}
|
||||
SHORT_SHA: ${{ steps.short_sha.outputs.SHORT_SHA }}
|
||||
run: |
|
||||
echo "updating PR description"
|
||||
DOCKER_RUN_COMMAND="docker run -it --rm \
|
||||
-p 3000:3000 \
|
||||
-v /var/run/docker.sock:/var/run/docker.sock \
|
||||
--add-host host.docker.internal:host-gateway \
|
||||
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=ghcr.io/all-hands-ai/runtime:$SHORT_SHA-nikolaik \
|
||||
--name openhands-app-$SHORT_SHA \
|
||||
ghcr.io/all-hands-ai/runtime:$SHORT_SHA"
|
||||
|
||||
PR_BODY=$(gh pr view $PR_NUMBER --json body --jq .body)
|
||||
|
||||
if echo "$PR_BODY" | grep -q "To run this PR locally, use the following command:"; then
|
||||
UPDATED_PR_BODY=$(echo "${PR_BODY}" | sed -E "s|docker run -it --rm.*|$DOCKER_RUN_COMMAND|")
|
||||
else
|
||||
UPDATED_PR_BODY="${PR_BODY}
|
||||
|
||||
---
|
||||
|
||||
To run this PR locally, use the following command:
|
||||
\`\`\`
|
||||
$DOCKER_RUN_COMMAND
|
||||
\`\`\`"
|
||||
fi
|
||||
|
||||
echo "updated body: $UPDATED_PR_BODY"
|
||||
gh pr edit $PR_NUMBER --body "$UPDATED_PR_BODY"
|
||||
|
||||
@@ -3,6 +3,8 @@ name: Resolve Issues with OpenHands
|
||||
on:
|
||||
issues:
|
||||
types: [labeled]
|
||||
pull_request:
|
||||
types: [labeled]
|
||||
|
||||
jobs:
|
||||
call-openhands-resolver:
|
||||
|
||||
@@ -174,6 +174,7 @@ evaluation/bird/data
|
||||
evaluation/gaia/data
|
||||
evaluation/gorilla/data
|
||||
evaluation/toolqa/data
|
||||
evaluation/scienceagentbench/benchmark
|
||||
|
||||
# frontend
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
<a href="https://codecov.io/github/All-Hands-AI/OpenHands?branch=main"><img alt="CodeCov" src="https://img.shields.io/codecov/c/github/All-Hands-AI/OpenHands?style=for-the-badge&color=blue"></a>
|
||||
<a href="https://github.com/All-Hands-AI/OpenHands/blob/main/LICENSE"><img src="https://img.shields.io/github/license/All-Hands-AI/OpenHands?style=for-the-badge&color=blue" alt="MIT License"></a>
|
||||
<br/>
|
||||
<a href="https://join.slack.com/t/opendevin/shared_invite/zt-2oikve2hu-UDxHeo8nsE69y6T7yFX_BA"><img src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge" alt="Join our Slack community"></a>
|
||||
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2tom0er4l-JeNUGHt_AxpEfIBstbLPiw"><img src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge" alt="Join our Slack community"></a>
|
||||
<a href="https://discord.gg/ESHStjSjD4"><img src="https://img.shields.io/badge/Discord-Join%20Us-purple?logo=discord&logoColor=white&style=for-the-badge" alt="Join our Discord community"></a>
|
||||
<a href="https://github.com/All-Hands-AI/OpenHands/blob/main/CREDITS.md"><img src="https://img.shields.io/badge/Project-Credits-blue?style=for-the-badge&color=FFE165&logo=github&logoColor=white" alt="Credits"></a>
|
||||
<br/>
|
||||
@@ -38,15 +38,15 @@ See the [Installation](https://docs.all-hands.dev/modules/usage/installation) gu
|
||||
system requirements and more information.
|
||||
|
||||
```bash
|
||||
docker pull docker.all-hands.dev/all-hands-ai/runtime:0.11-nikolaik
|
||||
docker pull docker.all-hands.dev/all-hands-ai/runtime:0.12-nikolaik
|
||||
|
||||
docker run -it --rm --pull=always \
|
||||
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.11-nikolaik \
|
||||
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.12-nikolaik \
|
||||
-v /var/run/docker.sock:/var/run/docker.sock \
|
||||
-p 3000:3000 \
|
||||
--add-host host.docker.internal:host-gateway \
|
||||
--name openhands-app \
|
||||
docker.all-hands.dev/all-hands-ai/openhands:0.11
|
||||
docker.all-hands.dev/all-hands-ai/openhands:0.12
|
||||
```
|
||||
|
||||
You'll find OpenHands running at [http://localhost:3000](http://localhost:3000)!
|
||||
@@ -59,7 +59,8 @@ works best, but you have [many options](https://docs.all-hands.dev/modules/usage
|
||||
|
||||
You can also [connect OpenHands to your local filesystem](https://docs.all-hands.dev/modules/usage/runtimes),
|
||||
run OpenHands in a scriptable [headless mode](https://docs.all-hands.dev/modules/usage/how-to/headless-mode),
|
||||
or interact with it via a [friendly CLI](https://docs.all-hands.dev/modules/usage/how-to/cli-mode).
|
||||
interact with it via a [friendly CLI](https://docs.all-hands.dev/modules/usage/how-to/cli-mode),
|
||||
or run it on tagged issues with [a github action](https://github.com/All-Hands-AI/OpenHands-resolver).
|
||||
|
||||
Visit [Installation](https://docs.all-hands.dev/modules/usage/installation) for more information and setup instructions.
|
||||
|
||||
@@ -92,7 +93,7 @@ For details, please check [CONTRIBUTING.md](./CONTRIBUTING.md).
|
||||
Whether you're a developer, a researcher, or simply enthusiastic about OpenHands, we'd love to have you in our community.
|
||||
Let's make software engineering better together!
|
||||
|
||||
- [Slack workspace](https://join.slack.com/t/opendevin/shared_invite/zt-2oikve2hu-UDxHeo8nsE69y6T7yFX_BA) - Here we talk about research, architecture, and future development.
|
||||
- [Slack workspace](https://join.slack.com/t/openhands-ai/shared_invite/zt-2tom0er4l-JeNUGHt_AxpEfIBstbLPiw) - Here we talk about research, architecture, and future development.
|
||||
- [Discord server](https://discord.gg/ESHStjSjD4) - This is a community-run server for general discussion, questions, and feedback.
|
||||
|
||||
## 📈 Progress
|
||||
|
||||
@@ -50,6 +50,7 @@ LLM_API_KEY="sk_test_12345"
|
||||
```bash
|
||||
docker run -it \
|
||||
--pull=always \
|
||||
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.12-nikolaik \
|
||||
-e SANDBOX_USER_ID=$(id -u) \
|
||||
-e WORKSPACE_MOUNT_PATH=$WORKSPACE_BASE \
|
||||
-e LLM_API_KEY=$LLM_API_KEY \
|
||||
@@ -58,7 +59,7 @@ docker run -it \
|
||||
-v /var/run/docker.sock:/var/run/docker.sock \
|
||||
--add-host host.docker.internal:host-gateway \
|
||||
--name openhands-app-$(date +%Y%m%d%H%M%S) \
|
||||
docker.all-hands.dev/all-hands-ai/openhands:0.11 \
|
||||
docker.all-hands.dev/all-hands-ai/openhands:0.12 \
|
||||
python -m openhands.core.cli
|
||||
```
|
||||
|
||||
@@ -107,4 +108,3 @@ Expected Output:
|
||||
```bash
|
||||
🤖 An error occurred. Please try again.
|
||||
```
|
||||
|
||||
|
||||
@@ -44,6 +44,7 @@ LLM_API_KEY="sk_test_12345"
|
||||
```bash
|
||||
docker run -it \
|
||||
--pull=always \
|
||||
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.12-nikolaik \
|
||||
-e SANDBOX_USER_ID=$(id -u) \
|
||||
-e WORKSPACE_MOUNT_PATH=$WORKSPACE_BASE \
|
||||
-e LLM_API_KEY=$LLM_API_KEY \
|
||||
@@ -52,7 +53,6 @@ docker run -it \
|
||||
-v /var/run/docker.sock:/var/run/docker.sock \
|
||||
--add-host host.docker.internal:host-gateway \
|
||||
--name openhands-app-$(date +%Y%m%d%H%M%S) \
|
||||
docker.all-hands.dev/all-hands-ai/openhands:0.11 \
|
||||
docker.all-hands.dev/all-hands-ai/openhands:0.12 \
|
||||
python -m openhands.core.main -t "write a bash script that prints hi"
|
||||
```
|
||||
|
||||
|
||||
@@ -11,15 +11,15 @@
|
||||
The easiest way to run OpenHands is in Docker.
|
||||
|
||||
```bash
|
||||
docker pull docker.all-hands.dev/all-hands-ai/runtime:0.11-nikolaik
|
||||
docker pull docker.all-hands.dev/all-hands-ai/runtime:0.12-nikolaik
|
||||
|
||||
docker run -it --rm --pull=always \
|
||||
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.11-nikolaik \
|
||||
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.12-nikolaik \
|
||||
-v /var/run/docker.sock:/var/run/docker.sock \
|
||||
-p 3000:3000 \
|
||||
--add-host host.docker.internal:host-gateway \
|
||||
--name openhands-app \
|
||||
docker.all-hands.dev/all-hands-ai/openhands:0.11
|
||||
docker.all-hands.dev/all-hands-ai/openhands:0.12
|
||||
```
|
||||
|
||||
You can also run OpenHands in a scriptable [headless mode](https://docs.all-hands.dev/modules/usage/how-to/headless-mode), as an [interactive CLI](https://docs.all-hands.dev/modules/usage/how-to/cli-mode), or using the [OpenHands GitHub Action](https://docs.all-hands.dev/modules/usage/how-to/github-action).
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
# DiscoveryBench with OpenHands
|
||||
|
||||
[DiscoveryBench](https://github.com/allenai/discoverybench/) [(Paper)](https://arxiv.org/abs/2407.01725v1) contains 264 tasks collected across 6 diverse domains, such as biology, economics, and sociology. It incorporates discovery workflows from published papers to approximate the real-world challenges faced by researchers.
|
||||
|
||||
<p align="center">
|
||||
<a href="[https://github.com/allenai/discoverybench](https://github.com/allenai/discoverybench)">
|
||||
<img src="https://raw.githubusercontent.com/allenai/discoverybench/refs/heads/main/assets/discoverybench-openhands-teaser.png" width="100%" alt="DiscoveryBench Background" />
|
||||
</a>
|
||||
</p>
|
||||
|
||||
|
||||
## Setup Environment and LLM Configuration
|
||||
|
||||
1. Please follow instructions mentioned [here](https://github.com/openlocus/OpenHands/blob/discoverybench-openhands-integration/evaluation/README.md#setup) to setup OpenHands development environment and LLMs locally
|
||||
|
||||
2. Execute the bash script to start DiscoveryBench Evaluation
|
||||
|
||||
```
|
||||
./evaluation/discoverybench/scripts/run_infer.sh [YOUR MODEL CONFIG]
|
||||
```
|
||||
Replace `[YOUR MODEL CONFIG]` with any model the model that you have set up in `config.toml`
|
||||
|
||||
|
||||
## Run Inference on DiscoveryBench Instances
|
||||
|
||||
When the `run_infer.sh` script is started, it will automatically pull the latest DiscoveryBench instances & set up the agent environment. The OpenHands agent is invoked to process the task within this environment, producing a hypothesis. We then evaluate it against the “gold” hypothesis provided by DiscoveryBench. The evaluation result, along with the agent chat history is logged to `output.jsonl` under `evaluation_outputs`.
|
||||
|
||||
|
||||
```
|
||||
./evaluation/discoverybench/scripts/run_infer.sh [MODEL_CONFIG] [GIT_COMMIT] [AGENT] [EVAL_LIMIT] [NUM_WORKERS]
|
||||
```
|
||||
|
||||
- `MODEL_CONFIG`: Name of the model you want to evaluate with
|
||||
- `GIT_COMMIT`: This should be the git commit hash or release tag for OpenHands, e.g., HEAD or a specific tag like 0.6.2.
|
||||
- `AGENT`: Use CoderActAgent, right now it only supports that.
|
||||
- `EVAL_LIMIT`: Number of samples to evaluate.
|
||||
- `NUM_WORKERS`: Number of workers to parallelize the evaluation process.
|
||||
@@ -0,0 +1,7 @@
|
||||
## DiscoveryBench Evaluation Utils
|
||||
|
||||
- **`eval_w_subhypo_gen.py`**: Implements the DiscoveryBench logic for evaluating agent-generated hypotheses.
|
||||
- **`lm_utils.py`**: Provides utility functions necessary for the evaluation process.
|
||||
- **`openai_helpers.py`**: Includes helper functions for OpenAI-related tasks.
|
||||
- **`openai_semantic_gen_prompts.py`**: Contains prompts used for semantic generation.
|
||||
- **`response_parser.py`**: Handles the parsing of agent-generated hypotheses.
|
||||
@@ -0,0 +1,538 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from .lm_utils import run_chatgpt_query_multi_turn
|
||||
from .openai_helpers import get_response
|
||||
|
||||
logging.basicConfig(
|
||||
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt='%m/%d/%Y %H:%M:%S',
|
||||
level=logging.INFO,
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_score_from_answer(type, answer):
|
||||
if type == 'context':
|
||||
answer = answer.replace('Answer:', '').strip()
|
||||
if answer.startswith('A)'):
|
||||
return 1.0
|
||||
elif answer.startswith('B)'):
|
||||
return 0.0
|
||||
return -1.0
|
||||
|
||||
elif type == 'var':
|
||||
try:
|
||||
var_json = json.loads(answer)
|
||||
# print(f"var_json:{var_json}")
|
||||
p = 0.0
|
||||
r = 0.0
|
||||
f1 = 0.0
|
||||
if var_json['sizeB']:
|
||||
p = var_json['intersection'] / var_json['sizeB']
|
||||
if var_json['sizeA']:
|
||||
r = var_json['intersection'] / var_json['sizeA']
|
||||
if p > 0.0 and r > 0.0:
|
||||
f1 = (2 * p * r) / (p + r)
|
||||
else:
|
||||
f1 = 0.0
|
||||
eval_rec = {
|
||||
'p': p,
|
||||
'r': r,
|
||||
'f1': f1,
|
||||
'sizeA': var_json['sizeA'],
|
||||
'sizeB': var_json['sizeB'],
|
||||
'intersection': var_json['intersection'],
|
||||
'explanation': var_json['explanation'],
|
||||
}
|
||||
print(f'var_eval: {eval_rec}')
|
||||
return eval_rec
|
||||
except Exception: # COMMENT: added Exception
|
||||
return {'p': -1.0, 'r': -1.0, 'f1': -1.0}
|
||||
elif type == 'rel':
|
||||
print(answer)
|
||||
rel_json = json.loads(answer)
|
||||
answer_str = rel_json['answer'].strip()
|
||||
if answer_str.startswith('A') or 'very similar' in answer_str:
|
||||
return 1.0
|
||||
elif (
|
||||
answer_str.startswith('B') or 'similar but general than HypoA' in answer_str
|
||||
):
|
||||
return 0.5
|
||||
elif answer_str.startswith('C') or 'different' in answer_str:
|
||||
return 0.0
|
||||
return -1.0
|
||||
return -1.0
|
||||
|
||||
|
||||
def ask_dimension_question(
|
||||
query,
|
||||
gold_hypo,
|
||||
gold_workflow,
|
||||
gen_hypo,
|
||||
gen_workflow,
|
||||
dataset_meta,
|
||||
llm_used,
|
||||
dimension,
|
||||
dataset_type,
|
||||
use_column_metadata=True,
|
||||
):
|
||||
dimension_question = ''
|
||||
answer = ''
|
||||
score = 0.0
|
||||
if dimension == 'var':
|
||||
score = {'p': -1.0, 'r': -1.0, 'f1': -1.0}
|
||||
num_tokens = 256
|
||||
num_retries = 1
|
||||
json_response = False
|
||||
|
||||
messages = [
|
||||
{
|
||||
'role': 'system',
|
||||
'content': 'You are an AI assistant that helps evaluate a data-driven hypothesis. You are a helpful assistant who is not talkative. You only respond with the exact answer to a query without additional conversation.',
|
||||
},
|
||||
]
|
||||
if dimension == 'context':
|
||||
dimension_question = """\
|
||||
Question: Is HypoB defined in the same context as HypoA?
|
||||
(Context refers to assumptions/stratification under which the hypotheses are defined.)
|
||||
Options: A) same B) different
|
||||
What is your answer?"""
|
||||
elif dimension == 'var':
|
||||
dimension_question = """\
|
||||
Question: For both HypoA and HypoB, what are the different variables found in the hypotheses? \
|
||||
Return your answer as a JSON object in the following format:
|
||||
```json
|
||||
{{
|
||||
"sizeA": num of variables used in HypoA
|
||||
"sizeB": num of variables used in HypoB
|
||||
"intersection": num of variables common in HypoA and HypoB. Use *fuzzy matching* to determine intersection, accounting for paraphrases or slightly different surface forms
|
||||
"explanation": a short text explanation about the variables
|
||||
}}```
|
||||
Answer:"""
|
||||
num_tokens = 512
|
||||
num_retries = 1
|
||||
json_response = True
|
||||
elif dimension == 'rel':
|
||||
dimension_question = """\
|
||||
Question: Does HypoB exhibit the same relation as HypoA?
|
||||
Compare using following example hierarchy of relationships (based on specificity): \
|
||||
"there exists a relationship" > "positive relationship" > "positive AND (linear OR quadratic)" > "positive AND linear".
|
||||
Options: A) very similar B) similar but general than HypoA C) different
|
||||
Return your answer as a JSON object in the following format:
|
||||
```json
|
||||
{{
|
||||
"answer": one of the options from A) very similar B) similar but general than HypoA C) different
|
||||
"explanation": a short text explanation about the relationship comparison
|
||||
}}```
|
||||
Answer:"""
|
||||
num_tokens = 512
|
||||
num_retries = 1
|
||||
json_response = True
|
||||
|
||||
datasets_json = prepare_dataset_metadata_json(
|
||||
dataset_meta, dataset_type=dataset_type, use_column_metadata=use_column_metadata
|
||||
)
|
||||
|
||||
dimension_question_str = f"""\
|
||||
You are going to compare two natural-language hypotheses HypoA and HypoB accompanied with optional workflows: WorkflowA for HypoA and WorkflowB for HypoB. \
|
||||
Both the hypotheses answer the natural language query "QUERY" over the dataset(s) described by dataset description(s) and column description(s) below. \
|
||||
Compare HypoA and HypoB in terms of three aspects: Contexts, Variables, and Relations. \
|
||||
E.g., for the hypothesis "From 1995 to 2009, the number of sandhill cranes around the tundra (Indigilka River) surged by an astounding ~10X":
|
||||
* Contexts refer to stratification of the data under which the given hypothesis is True. E.g., "For all women", "From 1995 to 2009".
|
||||
* Variables refer to the set of variables (either dependent or independent) that are mentioned in the hypothesis. E.g., number of sandhill cranes, location.
|
||||
* Relations refer to the form of relation between the variables. E.g., "surged by ~10x".
|
||||
|
||||
Answer following questions for a given pair of hypotheses, HypoA and HypoB, along with an explanation grounded on the QUERY and the DATASET(S).
|
||||
|
||||
Here is the metadata for the task:
|
||||
```json
|
||||
{{
|
||||
"datasets": {datasets_json},
|
||||
"query": {query},
|
||||
"HypoA": {gold_hypo},
|
||||
"WorkflowA": {gold_workflow},
|
||||
"HypoB": {gen_hypo},
|
||||
"WorkflowB": {gen_workflow}
|
||||
}}
|
||||
```
|
||||
|
||||
{dimension_question}"""
|
||||
|
||||
messages.append({'role': 'user', 'content': dimension_question_str})
|
||||
for retry in range(num_retries):
|
||||
response = run_chatgpt_query_multi_turn(
|
||||
messages=messages,
|
||||
model_name=llm_used,
|
||||
max_tokens=num_tokens,
|
||||
temperature=0, # 0 for greedy best decoding
|
||||
json_response=json_response,
|
||||
)
|
||||
if response is not None: # COMMENT: changed from != to is not
|
||||
break
|
||||
|
||||
if response is not None: # COMMENT: changed from != to is not
|
||||
answer = response.choices[0].message.content.strip()
|
||||
score = get_score_from_answer(type=dimension, answer=answer)
|
||||
|
||||
return dimension_question, answer, score
|
||||
|
||||
|
||||
def prepare_dataset_metadata_json(dataset_meta, dataset_type, use_column_metadata=True):
|
||||
if dataset_meta is None: # COMMENT: changed from == to is None
|
||||
return [
|
||||
{
|
||||
'dataset_description': '',
|
||||
'columns': [],
|
||||
}
|
||||
]
|
||||
datasets_json = []
|
||||
if dataset_type == 'real':
|
||||
for d in dataset_meta['datasets']:
|
||||
datasets_json.append(
|
||||
{
|
||||
'dataset_description': d['description'],
|
||||
'columns': [
|
||||
{'name': col['name'], 'description': col['description']}
|
||||
for col in d['columns']['raw']
|
||||
]
|
||||
if use_column_metadata
|
||||
else [],
|
||||
}
|
||||
)
|
||||
else:
|
||||
for d in dataset_meta['datasets']:
|
||||
datasets_json.append(
|
||||
{
|
||||
'dataset_description': d['description'],
|
||||
'columns': [
|
||||
{'name': col['name'], 'description': col['description']}
|
||||
for col in d['columns']
|
||||
]
|
||||
if use_column_metadata
|
||||
else [],
|
||||
}
|
||||
)
|
||||
return datasets_json
|
||||
|
||||
|
||||
def get_sub_hypotheses(
|
||||
query,
|
||||
hypo,
|
||||
workflow,
|
||||
dataset_meta,
|
||||
llm_used,
|
||||
dataset_type,
|
||||
use_column_metadata=True,
|
||||
):
|
||||
client = OpenAI()
|
||||
extraction_prompt = """\
|
||||
Given a set of dataset columns, a ground-truth hypothesis, and the analysis workflow used, your task is to extract three dimensions that define the hypothesis: Context, Variables, and Relations. \
|
||||
Here are the definitions for these dimensions:
|
||||
- Contexts: Boundary conditions that limit the scope of a hypothesis. E.g., “for men over \
|
||||
the age of 30”, “in Asia and Europe”. If the context applies to the full dataset, then extract the context from the dataset_descrption.
|
||||
- Variables: Known concepts that interact in a meaningful way under a given context to \
|
||||
produce the hypothesis. E.g., gender, age, income, or "None" if there is no interacting variable.
|
||||
- Relations: Interactions between a given set of variables under a given context to produce \
|
||||
the hypothesis. E.g., “quadratic relationship”, “inversely proportional”, piecewise conditionals, \
|
||||
or "None" if there is no interacting relationship.
|
||||
Make sure to only use the information present in the hypothesis and the workflow. Do not add any new information. \
|
||||
For each dimension, be specific, and do not omit any important details.
|
||||
|
||||
Here is the metadata for the task:
|
||||
```json
|
||||
{
|
||||
"datasets": %s,
|
||||
"hypothesis": "%s",
|
||||
"workflow": "%s"
|
||||
}
|
||||
```
|
||||
|
||||
Return your answer as a JSON object in the following format:
|
||||
```json
|
||||
{
|
||||
"sub_hypo": [
|
||||
{
|
||||
"text": the hypothesis in natural language,
|
||||
"context": a short text description of the context of the hypothesis,
|
||||
"variables": a list of columns involved in the hypothesis,
|
||||
"relations": a short text description of the relationship between the variables of the hypothesis
|
||||
},
|
||||
...
|
||||
]
|
||||
}```
|
||||
"""
|
||||
datasets_json = prepare_dataset_metadata_json(
|
||||
dataset_meta, dataset_type, use_column_metadata=use_column_metadata
|
||||
)
|
||||
_prompt = extraction_prompt % (datasets_json, hypo, workflow)
|
||||
sub_hypo_json = get_response(client, _prompt, model=llm_used, max_retry=1)
|
||||
|
||||
if sub_hypo_json is not None: # COMMENT: changed from != to is not
|
||||
# print(f"full hypothesis: {hypo}")
|
||||
print(f'sub_hypo_json: {sub_hypo_json}')
|
||||
else:
|
||||
sub_hypo_json = {
|
||||
'sub_hypo': [],
|
||||
}
|
||||
|
||||
sub_hypo_json['full_hypo'] = hypo
|
||||
|
||||
return sub_hypo_json
|
||||
|
||||
|
||||
def match_context_with_gpt(
|
||||
gold_hyp, gold_context, pred_hyp, pred_context, model='gpt-3.5-turbo'
|
||||
):
|
||||
prompt = f"""\
|
||||
Given a gold hypothesis, a gold context, a predicted hypothesis, and a predicted context, your task is \
|
||||
to determine if the predicted context semantically matches the ground-truth context. \
|
||||
Here is the definition for Context: Boundary conditions that limit the scope of a sub-hypothesis. E.g., “for men over the age of 30”, “in Asia and Europe”. If the context applies to the full dataset, then the context is derived from the dataset_descrption. \
|
||||
Here is the definition for Context: Boundary conditions that limit the scope of a sub-hypothesis. E.g., “for men over the age of 30”, “in Asia and Europe”. If the context applies to the full dataset, then the context is derived from the dataset_descrption. \
|
||||
If the predicted context matches the gold context, return true, otherwise return false.
|
||||
If both gold and predicted hypotheses are defined over the context of the full dataset, then also return true.
|
||||
If both gold and predicted hypotheses are defined over the context of the full dataset, then also return true.
|
||||
|
||||
Here is the metadata for the task:
|
||||
```json
|
||||
{{
|
||||
"gold_hypothesis": "{gold_hyp}",
|
||||
"gold_context": "{gold_context}",
|
||||
"predicted_hypothesis": "{pred_hyp}",
|
||||
"predicted_context": "{pred_context}"
|
||||
}}
|
||||
```
|
||||
|
||||
Return your answer as a JSON object in the following format:
|
||||
```json
|
||||
{{
|
||||
"match": true or false
|
||||
}}
|
||||
```"""
|
||||
|
||||
client = OpenAI()
|
||||
output = get_response(client, prompt, model=model)
|
||||
return output.get('match', False)
|
||||
|
||||
|
||||
def is_matching_context(gold_hyp, gold_context, pred_hyp, pred_context, llm_used):
|
||||
if gold_context == pred_context:
|
||||
return True
|
||||
if 'None' in [gold_context, pred_context]:
|
||||
return False
|
||||
return match_context_with_gpt(
|
||||
gold_hyp, gold_context, pred_hyp, pred_context, model=llm_used
|
||||
)
|
||||
|
||||
|
||||
def run_eval_gold_vs_gen_NL_subhypo(
|
||||
query,
|
||||
gold_hypo,
|
||||
gold_workflow,
|
||||
gen_hypo,
|
||||
gen_workflow,
|
||||
dataset_meta,
|
||||
llm_used,
|
||||
context_score,
|
||||
dataset_type,
|
||||
use_column_metadata=True,
|
||||
):
|
||||
# GPT-4 based evaluation to evaluate generated hypothesis in terms of context, variables, relation
|
||||
|
||||
eval_rec = {
|
||||
'query': query,
|
||||
'HypoA': gold_hypo,
|
||||
'WorkflowA': gold_workflow,
|
||||
'HypoB': gen_hypo,
|
||||
'WorkflowB': gen_workflow,
|
||||
}
|
||||
|
||||
for dimension in ['var', 'rel']:
|
||||
question, answer, score = ask_dimension_question(
|
||||
query,
|
||||
gold_hypo,
|
||||
gold_workflow,
|
||||
gen_hypo,
|
||||
gen_workflow,
|
||||
dataset_meta,
|
||||
llm_used,
|
||||
dimension=dimension,
|
||||
dataset_type=dataset_type,
|
||||
use_column_metadata=use_column_metadata,
|
||||
)
|
||||
|
||||
eval_rec[dimension] = {'question': question, 'answer': answer, 'score': score}
|
||||
|
||||
eval_rec['context'] = context_score
|
||||
eval_rec['accuracy_score'] = (
|
||||
1.0
|
||||
* eval_rec['context']['score']
|
||||
* eval_rec['var']['score']['f1']
|
||||
* eval_rec['rel']['score']
|
||||
)
|
||||
|
||||
return eval_rec
|
||||
|
||||
|
||||
def run_eval_gold_vs_gen_NL_hypo_workflow(
|
||||
query,
|
||||
gold_hypo,
|
||||
gold_workflow,
|
||||
gen_hypo,
|
||||
gen_workflow,
|
||||
dataset_meta,
|
||||
llm_used,
|
||||
dataset_type,
|
||||
use_column_metadata=True,
|
||||
):
|
||||
# Input: Dataset Metadata, Query, Gold {Hg, Wg}, Predicted {Hp, Wp}
|
||||
# Output: eval_rec json includes final_score
|
||||
|
||||
# Procedure:
|
||||
# Dataset Metadata, Query, Gold {Hg, Wg}, Pred {Hg, Wg}
|
||||
# Gold: [Hg1, Hg2] (compute on the fly) Hg1 is a NL form of subhypothesis
|
||||
# Predicted: [Hp1, Hp2] (compute on the fly)
|
||||
|
||||
# Compute Intersection: [(Hg_i, Hp_j), …] # tuples of (gold,pred) that matched with context (do this w/o explicit extraction)
|
||||
# # filter so that a gold context and a predicted context are only attached to one tuple
|
||||
# Compute recall_context (programmatically)
|
||||
|
||||
# r_v_list = []
|
||||
# For (Hg_i, Hp_j) in the intersection:
|
||||
# With Hg_i, Hp_j in NL, ask GPT4 → #variables and #intersection and a paragraph explanation and programmatically calculate f1_v
|
||||
# Hg_i, Hp_j in NL, ask GPT4 → matching score (0, 0.5 or 1) : A) very similar B) similar but general than HypoA C) different + explanation
|
||||
# r_v_list ← f1_v * score_r
|
||||
# accuracy_score = mean(r_v_list)
|
||||
# score = [ recall_context * mean over predicted context(context_score * var_score *rel_score )]
|
||||
|
||||
# recall_context = 1.0 # COMMENT: never used
|
||||
eval_rec = {
|
||||
'query': query,
|
||||
'HypoA': gold_hypo,
|
||||
'WorkflowA': gold_workflow,
|
||||
'HypoB': gen_hypo,
|
||||
'WorkflowB': gen_workflow,
|
||||
}
|
||||
|
||||
gold_sub_hypo_json = get_sub_hypotheses(
|
||||
query=query,
|
||||
hypo=gold_hypo,
|
||||
workflow=gold_workflow,
|
||||
dataset_meta=dataset_meta,
|
||||
llm_used=llm_used,
|
||||
dataset_type=dataset_type,
|
||||
use_column_metadata=use_column_metadata,
|
||||
)
|
||||
if len(gold_sub_hypo_json['sub_hypo']) == 0:
|
||||
gold_sub_hypo_json['sub_hypo'] = [
|
||||
{
|
||||
'text': gold_hypo,
|
||||
'context': 'None',
|
||||
'variables': [],
|
||||
'relations': '',
|
||||
'explanation': 'unable to segment',
|
||||
}
|
||||
]
|
||||
print(f'gold_sub_hypo_json: {gold_sub_hypo_json}')
|
||||
|
||||
gen_sub_hypo_json = get_sub_hypotheses(
|
||||
query=query,
|
||||
hypo=gen_hypo,
|
||||
workflow=gen_workflow,
|
||||
dataset_meta=dataset_meta,
|
||||
llm_used=llm_used,
|
||||
dataset_type=dataset_type,
|
||||
use_column_metadata=use_column_metadata,
|
||||
)
|
||||
if len(gen_sub_hypo_json['sub_hypo']) == 0:
|
||||
gen_sub_hypo_json['sub_hypo'] = [
|
||||
{
|
||||
'text': gen_hypo,
|
||||
'context': 'None',
|
||||
'variables': [],
|
||||
'relations': '',
|
||||
'explanation': 'unable to segment',
|
||||
}
|
||||
]
|
||||
print(f'gen_sub_hypo_json: {gen_sub_hypo_json}')
|
||||
|
||||
eval_rec['gold_sub_hypo'] = gold_sub_hypo_json
|
||||
eval_rec['gen_sub_hypo'] = gen_sub_hypo_json
|
||||
|
||||
gold_subh_covered = []
|
||||
gen_subh_to_gold_subh = dict()
|
||||
gen_gold_subh_to_context = dict()
|
||||
|
||||
for p_id, gen_subh in enumerate(gen_sub_hypo_json['sub_hypo']):
|
||||
gen_subh_to_gold_subh[p_id] = -1
|
||||
|
||||
for g_id, gold_subh in enumerate(gold_sub_hypo_json['sub_hypo']):
|
||||
if g_id in gold_subh_covered:
|
||||
continue
|
||||
|
||||
# match context
|
||||
context_bool = is_matching_context(
|
||||
gold_subh['text'],
|
||||
gold_subh.get('context', ''),
|
||||
gen_subh['text'],
|
||||
gen_subh.get('context', ''),
|
||||
llm_used,
|
||||
)
|
||||
if context_bool:
|
||||
context_score = 1.0
|
||||
else:
|
||||
context_score = 0.0
|
||||
|
||||
if context_score == 1.0: # match only when context_score = 1.0
|
||||
gen_subh_to_gold_subh[p_id] = g_id
|
||||
gold_subh_covered.append(g_id)
|
||||
gen_gold_subh_to_context[f'P{p_id}||G{g_id}'] = {
|
||||
'question': f"""Comapring: GoldH: {gold_subh["text"]}, GoldC: {gold_subh['context']}\nGenH: {gen_subh['text']}, GenC: {gen_subh['context']}""",
|
||||
'answer': context_bool,
|
||||
'score': context_score,
|
||||
}
|
||||
break
|
||||
|
||||
print(f'gen_subh_to_gold_subh: {gen_subh_to_gold_subh}')
|
||||
eval_rec['gen_subh_to_gold_subh'] = gen_subh_to_gold_subh
|
||||
eval_rec['gold_subh_covered'] = gold_subh_covered
|
||||
matched_gold_gen_subh_evals = dict()
|
||||
sum_accuracy_score = 0.0
|
||||
for p_id, g_id in gen_subh_to_gold_subh.items():
|
||||
if g_id >= 0:
|
||||
key = f'P{p_id}||G{g_id}'
|
||||
context_score = gen_gold_subh_to_context[key]
|
||||
subh_eval_rec = run_eval_gold_vs_gen_NL_subhypo(
|
||||
query,
|
||||
gold_hypo,
|
||||
gold_workflow,
|
||||
gen_hypo,
|
||||
gen_workflow,
|
||||
dataset_meta,
|
||||
llm_used,
|
||||
context_score,
|
||||
dataset_type=dataset_type,
|
||||
use_column_metadata=use_column_metadata,
|
||||
)
|
||||
sum_accuracy_score += subh_eval_rec['accuracy_score']
|
||||
matched_gold_gen_subh_evals[key] = subh_eval_rec
|
||||
|
||||
eval_rec['matched_gold_gen_subh_evals'] = matched_gold_gen_subh_evals
|
||||
eval_rec['recall_context'] = (
|
||||
len(gold_subh_covered) / len(gold_sub_hypo_json['sub_hypo'])
|
||||
if len(gold_sub_hypo_json['sub_hypo'])
|
||||
else 0.0
|
||||
)
|
||||
mean_accuracy_score = (
|
||||
sum_accuracy_score / len(gen_subh_to_gold_subh)
|
||||
if len(gen_subh_to_gold_subh)
|
||||
else 0.0
|
||||
)
|
||||
eval_rec['mean_accuracy_score'] = mean_accuracy_score
|
||||
final_score = eval_rec['recall_context'] * mean_accuracy_score
|
||||
eval_rec['final_score'] = final_score
|
||||
print(f'eval_rec: {json.dumps(eval_rec, indent=2)}')
|
||||
|
||||
return eval_rec
|
||||
@@ -0,0 +1,64 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
from openai import OpenAI
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt, # type: ignore
|
||||
wait_random_exponential, # type: ignore
|
||||
)
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
from typing import Literal
|
||||
else:
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
Model = Literal['gpt-4', 'gpt-3.5-turbo', 'text-davinci-003']
|
||||
|
||||
OpenAI.api_key = os.getenv('OPENAI_API_KEY')
|
||||
OPENAI_GEN_HYP = {
|
||||
'temperature': 0,
|
||||
'max_tokens': 250,
|
||||
'top_p': 1.0,
|
||||
'frequency_penalty': 0,
|
||||
'presence_penalty': 0,
|
||||
}
|
||||
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
||||
def run_chatgpt_query_multi_turn(
|
||||
messages,
|
||||
model_name='gpt-4-turbo', # pass "gpt4" for more recent model output
|
||||
max_tokens=256,
|
||||
temperature=0.0,
|
||||
json_response=False,
|
||||
):
|
||||
response = None
|
||||
num_retries = 3
|
||||
retry = 0
|
||||
while retry < num_retries:
|
||||
retry += 1
|
||||
try:
|
||||
client = OpenAI()
|
||||
|
||||
if json_response:
|
||||
response = client.chat.completions.create(
|
||||
model=model_name,
|
||||
response_format={'type': 'json_object'},
|
||||
messages=messages,
|
||||
**OPENAI_GEN_HYP,
|
||||
)
|
||||
else:
|
||||
response = client.chat.completions.create(
|
||||
model=model_name, messages=messages, **OPENAI_GEN_HYP
|
||||
)
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print('GPT error. Retrying in 2 seconds...')
|
||||
time.sleep(2)
|
||||
|
||||
return response
|
||||
@@ -0,0 +1,190 @@
|
||||
import json
|
||||
|
||||
|
||||
def OPENAI_TOPIC_GEN_MESSAGES(n=10):
|
||||
return [
|
||||
{
|
||||
'role': 'system',
|
||||
'content': 'You are a helpful assistant who is not talkative. You only respond with the exact answer to a query without additional conversation.',
|
||||
},
|
||||
{
|
||||
'role': 'user',
|
||||
'content': f'Given `n`, come up with a list of `n` distinct topics and their descriptions. The topics can be absolutely anything. Be as creative as possible. Return your answer as a JSON object. \n\nFor example, for `n`=3, a valid answer might be:\n```json\n{{"topics": [\n {{"id": 1, "topic": "cooking", "description": "Related to recipes, ingredients, chefs, etc."}},\n {{"id": 2, "topic": "sports", "description": "Related to players, stadiums, trophies, etc."}},\n {{"id": 3, "topic": "antiquing", "description": "Related to unique items, history, etc."}}\n]}}```\n\nNow, give me a list for `n`={n}. Remember, pick diverse topics from everything possible. No consecutive topics should be broadly similar. Directly respond with the answer JSON object.',
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
OPENAI_GEN_HYP = {
|
||||
'temperature': 1.0,
|
||||
'max_tokens': 4096,
|
||||
'top_p': 1.0,
|
||||
'frequency_penalty': 0,
|
||||
'presence_penalty': 0,
|
||||
}
|
||||
|
||||
|
||||
def OPENAI_SEMANTICS_GEN_MESSAGES(dependent, relationship, domain, domain_desc):
|
||||
return [
|
||||
{
|
||||
'role': 'system',
|
||||
'content': 'You are a helpful assistant who is not talkative. You only respond with the exact answer to a query without additional conversation.',
|
||||
},
|
||||
{
|
||||
'role': 'user',
|
||||
'content': f'Given the true relationship in a dataset and a given domain, your task is to come up with an interpretation of some real-world concepts that the relationship could be modeling from the provided domain. It\'s okay to be wrong, but suggest something reasonable. Try as much as possible to make sure that the TARGET is actually derivable from the other variables. Give your answer as a JSON object. Here\'s an example:\n\nRelationship for x2 = "(96.4 * x1 ** 3) + (88.72 * x5 ** 2) + (81.96 * x6 ** -2) + (28.13 * x3) + (97.0) + (0 * x4)"\nDomain="Sales"\nDomain description="Related to product distribution, revenues, marketing, etc."\n\nBased on this, the following real-world concepts might be applicable:\n```json\n{{\n "dependent": "x2",\n "relationship": "(96.4 * x1 ** 3) + (88.72 * x5 ** 2) + (81.96 * x6 ** -2) + (28.13 * x3) + (97.0) + (0 * x4)",\n "domain": "Sales",\n "trends": {{\n "x1": "Positive, cubic factor",\n "x2": "TARGET",\n "x3": "Positive, linear factor",\n "x4": "No relation",\n "x5": "Positive quadratic factor",\n "x6": "Positive, inverse quadratic factor"\n }},\n "interpretation": {{\n "x2": {{"description": "Volume of product sales by area", "name": "sales_area", "is_target": true}},\n "x1": {{"description": "Population by area", "name": "pop_area"}},\n "x3": {{"description": "Advertising spending", "name": "ad_spend"}},\n "x4": {{"description": "Gender ratio of marketing team", "name": "gdr_ratio_mkt_team"}},\n "x5": {{"description": "Intensity of marketing campaign", "name": "mkt_intensity"}}\n }},\n "x6": {{"description": "Distance to distribution center", "name": "dist_to_distr_ctr"}}\n}}```\n\nHere\'s a new test question:\nRelationship for {dependent} = "{relationship}"\nDomain = "{domain}"\nDomain description="{domain_desc}"\n\nRespond only with the answer JSON. Make sure that you do not forget to include the TARGET variable in the interpretation object.',
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def OPENAI_SEMANTICS_GEN_W_MAP_MESSAGES(
|
||||
dependent, relationship, domain, domain_desc, mapping
|
||||
):
|
||||
return [
|
||||
{
|
||||
'role': 'system',
|
||||
'content': 'You are a helpful assistant who is not talkative. You only respond with the exact answer to a query without additional conversation.',
|
||||
},
|
||||
{
|
||||
'role': 'user',
|
||||
'content': f'Given a partial mapping from variables to real-world concepts and a true relationship in a dataset, your task is to come up with an interpretation of real-world concepts for the variables without any assigned mapping (those starting with x). Suggest something reasonable. The dependent variable must be derivable only from the other variables in the dependent relationship. Give your answer as a JSON object. Here\'s an example:\n\nExample partial mapping and relationship:\n```json\n{{\n "domain": "Sales",\n "domain_description": "Related to product distribution, revenues, marketing, etc.",\n "variable_mapping": {{\n "x1": {{"description": "Population by area", "name": "pop_area"}},\n "x2": {{"description": "Volume of product sales by area", "name": "sales_area"}},\n "x4": {{"description": "Gender ratio of marketing team", "name": "gdr_ratio_mkt_team"}},\n "x6": {{"description": "Distance to distribution center", "name": "dist_to_distr_ctr"}}\n }},\n "dependent_variable": "sales_area",\n "dependent_relationship": "(96.4 * pop_area ** 3) + (88.72 * x5 ** 2) + (81.96 * dist_to_distr_ctr ** -2) + (28.13 * x3) + (97.0)"\n}}```\nBased on this, an example answer would be:\n```json\n{{\n "dependent_variable": "sales_area",\n "missing_mapping": ["x3", "x5"],\n "trends": {{\n "x3": "Positive, linear factor",\n "x5": "Positive quadratic factor"\n }},\n "interpretation": {{\n "x3": {{"description": "Advertising spending", "name": "ad_spend"}},\n "x5": {{"description": "Intensity of marketing campaign", "name": "mkt_intensity"}}\n }}\n}}```\n\nHere\'s a new test question:\n```json\n{{\n "domain": "{domain}",\n "domain_description": "{domain_desc}",\n "variable_mapping": {json.dumps(mapping, indent=2)},\n "dependent_variable": "{dependent}",\n "dependent_relationship": "{relationship}"\n}}```\nRespond only with the answer JSON.',
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def OPENAI_SEMANTICS_GEN_SUMMARY_MESSAGES(dataset):
|
||||
return [
|
||||
{
|
||||
'role': 'system',
|
||||
'content': 'You are a helpful assistant who is not talkative. You only respond with the exact answer to a query without additional conversation.',
|
||||
},
|
||||
{
|
||||
'role': 'user',
|
||||
'content': f'Given the following descriptions of the columns of a dataset, your task is to come up with a natural language overview of the dataset, which should include (1) what the dataset is about, (2) how the data was collected, (3) when the data was collected, and (3) for what purpose the data was collected. Be specific and creative.\n\nExample dataset:\n```json\n{{ \n "dataset": {{ \n "x6": {{"description": "Ancient artifact significance score", "name": "artifact_significance_score", "is_target": true}},\n "x1": {{"description": "Distance to ancient city center", "name": "dist_to_ancient_city_ctr"}},\n "x2": {{"description": "Quantity of discovered relics", "name": "relic_discovery_qty"}},\n "x3": {{"description": "Years since last archaeological expedition", "name": "years_since_exp"}},\n "x4": {{"description": "Number of artifacts in excavation site", "name": "artifact_qty"}},\n "x5": {{"description": "Soil fertility coefficient", "name": "soil_fertility_coef"}},\n "x7": {{"description": "Distance to ancient burial grounds", "name": "dist_to_burial_grounds"}},\n "x8": {{"description": "Population estimate of ancient civilization", "name": "ancient_civilization_pop_estimate"}},\n "x9": {{"description": "Temperature variation in excavation region", "name": "temp_variation"}}\n }}\n}}```\nExample description:\nThis dataset is about archaeological explorations and findings linked to ancient civilizations. The data was collected in the form of field metrics during various archaeological expeditions during the late mid-20th century. The purpose of the data collection is to evaluate the significance of ancient artifacts discovered during excavations.\n\nHere is a new test dataset.\n{json.dumps(dataset, indent=2)}\nProvide only the description.',
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def OPENAI_GEN_HYPO_MESSAGES(dataset):
|
||||
return [
|
||||
{
|
||||
'role': 'system',
|
||||
'content': 'You are a helpful assistant who is not talkative. You only respond with the exact answer to a query without additional conversation.',
|
||||
},
|
||||
{
|
||||
'role': 'user',
|
||||
'content': f'Given a dataset with its descriptions and the true functional relationship between its variables, your task is to generate 3 levels of hypotheses for the stated relationship in plain English. The three levels are "broad", "medium" and "narrow". Make sure that the hypotheses sound natural. *Only include concepts for variables that are present in the provided functional relationship.* Give your answer as a JSON.\n\nFor example, an example dataset might be the following:\n```json\n{{\n "domain": "cybersecurity",\n "summary": "This dataset is about measuring cybersecurity threats in a system. The data was collected by monitoring various cybersecurity metrics in a network environment. The purpose of the data collection is to assess and predict potential cybersecurity risks and vulnerabilities.",\n "variables": [\n {{\n "description": "Level of cybersecurity threat",\n "name": "cybersecurity_threat",\n "is_target": true\n }},\n {{\n "description": "Number of failed login attempts",\n "name": "failed_login_attempts"\n }},\n {{\n "description": "Amount of encrypted data",\n "name": "encrypted_data"\n }},\n {{\n "description": "Frequency of software updates",\n "name": "software_updates"\n }},\n {{\n "description": "Number of antivirus software installed",\n "name": "antivirus_software"\n }},\n {{\n "description": "Quality of firewall protection",\n "name": "firewall_quality"\n }}\n ],\n "relationship": {{\n "dependent": "cybersecurity_threat",\n "relation": "-53.5*encrypted_data**2 - 53.85*failed_login_attempts**2 + 67.75*firewall_quality - 92.16 - 36.68/software_updates**3"\n }}\n}}```\nGiven this dataset, the following is a valid answer:\n```json\n{{\n "broad": {{\n "instruction": "Be vague. Only indicate which concepts might be related but not how they are related",\n "hypothesis": "Threat to cybersecurity is influenced by several factors including the amount of encrypted data, the number of failed login attempts, the quality of the firewall, as well as how often the software is updated."\n }},\n "medium": {{\n "instruction": "Be slightly more specific. For each factor, indicate carefully whether it positively or negatively affects the relationship, but do not indicate what the exponent is.",\n "hypothesis": "Cybersecurity threat tends to decrease with the amount of data encryption, the number of failed login attempts, as well as the frequency of software updates to some extent, while improvement in the firewall quality has a positive effect."\n }},\n "narrow": {{\n "instruction": "Be specific. Communicate the concepts, whether there is a positive or negative effect (be careful), and the meaning of the exponent",\n "hypothesis": "The threat to cybersecurity interacts in a complex manner with various factors. As the amount of encrypted data increases, there is a quadratic decrease in threat. Similarly for the number of failed login attempts, there is a negative quadratic relationship. The quality of the firewall protection on the other hand demonstrates a positive and linear relationship. Finally, the frequency of software updates has an inverse cubic relationship to the threat."\n }},\n}}\n```\n\nBased on this, provide an answer for the following test dataset:\n```json\n{dataset}```\nRespond only with a JSON.',
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def create_prompt(usr_msg):
|
||||
return [
|
||||
{
|
||||
'role': 'system',
|
||||
'content': 'You are a helpful assistant who is not talkative. You only respond with the exact answer to a query without additional conversation.',
|
||||
},
|
||||
{'role': 'user', 'content': usr_msg},
|
||||
]
|
||||
|
||||
|
||||
def get_response(client, prompt, max_retry=5, model='gpt-3.5-turbo', verbose=False):
|
||||
n_try = 0
|
||||
while n_try < max_retry:
|
||||
response = client.chat.completions.create(
|
||||
model=model, messages=create_prompt(prompt), **OPENAI_GEN_HYP
|
||||
)
|
||||
|
||||
# COMMENT: changed from
|
||||
# response.choices[0].message.content.strip().strip('```json').strip('```')
|
||||
content = response.choices[0].message.content
|
||||
cleaned_content = content.split('```json')[1].split('```')[0].strip()
|
||||
output = cleaned_content
|
||||
try:
|
||||
response_json = json.loads(output)
|
||||
return response_json
|
||||
except ValueError:
|
||||
if verbose:
|
||||
print(f'Bad JSON output:\n\n{output}')
|
||||
n_try += 1
|
||||
if n_try < max_retry:
|
||||
if verbose:
|
||||
print('Retrying...')
|
||||
else:
|
||||
if verbose:
|
||||
print('Retry limit reached')
|
||||
return None
|
||||
|
||||
|
||||
def get_code_fix(
|
||||
client, code, error, max_retry=5, model='gpt-3.5-turbo', verbose=False
|
||||
):
|
||||
prompt = f"""\
|
||||
Given the following code snippet and error message, provide a single-line fix for the error. \
|
||||
Note that the code is going to be executed using python `eval`. \
|
||||
The code should be executable and should not produce the error message. Be as specific as possible.
|
||||
|
||||
Here's the code and the error:
|
||||
{{
|
||||
"code": "{code}",
|
||||
"error": "{error}"
|
||||
}}
|
||||
|
||||
Return only a JSON object with the fixed code in the following format:
|
||||
```json
|
||||
{{
|
||||
"fixed_code": "..."
|
||||
}}"""
|
||||
response = get_response(
|
||||
client, prompt, max_retry=max_retry, model=model, verbose=verbose
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
def get_new_hypothesis(
|
||||
client, target, old, expr, cols, model='gpt-3.5-turbo', verbose=False
|
||||
):
|
||||
prompt = f"""\
|
||||
Given a target column from a dataset, a pandas expression to derive the column from existing columns, a list of \
|
||||
existing columns, and a previously written hypothesis text, carefully check if the hypothesis text is consistent with \
|
||||
the pandas expression or not. If it is consistent, simply return the hypothesis as it is. If it is not consistent, \
|
||||
provide a new natural language hypothesis that is consistent with the pandas expression using only the provided \
|
||||
information. Be specific.
|
||||
|
||||
Here's the information:
|
||||
```json
|
||||
{{
|
||||
"target_column": "{target}",
|
||||
"pandas_expression": "{expr}",
|
||||
"existing_columns": {json.dumps(cols, indent=4)}
|
||||
"old_hypothesis": "{old}",
|
||||
}}```
|
||||
|
||||
Give your answer as a new JSON with the following format:
|
||||
```json
|
||||
{{
|
||||
"hypothesis": "..."
|
||||
}}"""
|
||||
response = get_response(client, prompt, model=model, verbose=verbose)
|
||||
return response
|
||||
|
||||
|
||||
def replace_variable(client, expr, old, new, model='gpt-3.5-turbo', verbose=False):
|
||||
prompt = f"""\
|
||||
Given a pandas "expression", replace mentions of the "old" column with its "new" value such that the resultant \
|
||||
expression is equivalent to the original expression.
|
||||
|
||||
Here's the information:
|
||||
```json
|
||||
{{
|
||||
"expression": "{expr}",
|
||||
"old": "{old}",
|
||||
"new": "{new}"
|
||||
}}```
|
||||
|
||||
Give your answer as a new JSON with the following format:
|
||||
```json
|
||||
{{
|
||||
"new_expression": "..."
|
||||
}}"""
|
||||
response = get_response(client, prompt, model=model, verbose=verbose)
|
||||
return response
|
||||
@@ -0,0 +1,151 @@
|
||||
common_hypothesis_features = [
|
||||
'1-2 sentences',
|
||||
'surprising finding',
|
||||
'includes numeric concepts',
|
||||
'includes categorical concepts',
|
||||
'includes binary concepts',
|
||||
]
|
||||
hypothesis_features = [
|
||||
['requires within-cluster analysis'],
|
||||
['requires across-cluster analysis'],
|
||||
['corresponds to a polynomial relationship of some columns'],
|
||||
['corresponds to a ratio between some columns'],
|
||||
['requires temporal analysis'],
|
||||
['relationship is based on descriptive statistics of some columns'],
|
||||
['requires concepts based on percentage or percentiles'],
|
||||
['relationship is only applicable to one cluster in the data and not the others'],
|
||||
]
|
||||
|
||||
column_features = [
|
||||
[
|
||||
'must have one target column',
|
||||
'must have quantifiable columns',
|
||||
'must have a few categorical columns',
|
||||
'make sure the categorical column values do not contain special characters',
|
||||
'include a few distractor columns',
|
||||
]
|
||||
]
|
||||
|
||||
common_pandas_features = [
|
||||
'must be executable using python `eval` to create the target column in variable `df` (pandas dataframe)',
|
||||
"for e.g., df['A']**2 + 3*df['B'] + 9, np.where(df['A'] > 3, 'Yes', 'No'), etc.",
|
||||
'variables in pandas_expression must be from the existing columns listed above',
|
||||
'variables in pandas_expression must NOT contain the target column itself',
|
||||
]
|
||||
pandas_features = [
|
||||
['expression is a quadratic polynomial'],
|
||||
['expression is a cubic polynomial'],
|
||||
['expression is a ratio of existing columns'],
|
||||
['expression is derived through logical combination of existing columns'],
|
||||
# workflow
|
||||
]
|
||||
pandas_features = [common_pandas_features + p for p in pandas_features]
|
||||
|
||||
common_derived_features = [
|
||||
'1-2 sentences',
|
||||
'includes numeric concepts',
|
||||
'includes categorical concepts',
|
||||
'includes binary concepts',
|
||||
]
|
||||
derived_features = [common_derived_features + h for h in hypothesis_features]
|
||||
hypothesis_features = [common_hypothesis_features + h for h in hypothesis_features]
|
||||
|
||||
PROMPT_HYP = """\
|
||||
Given a dataset topic and description, generate an interesting hypothesis based on \
|
||||
the provided instructions. Be creative and come up with an unusual finding.
|
||||
|
||||
```json
|
||||
{
|
||||
"topic": "%s",
|
||||
"description": "%s",
|
||||
"hypothesis_features": %s,
|
||||
"hypothesis": "..."
|
||||
}```
|
||||
|
||||
Give your answer as a new JSON with the following format:
|
||||
```json
|
||||
{
|
||||
"hypothesis": "..."
|
||||
}
|
||||
```"""
|
||||
|
||||
PROMPT_COL = """\
|
||||
Given a dataset topic, its description, and a true hypothesis that can be determined from it, \
|
||||
generate a list of valid columns based on the provided instructions.
|
||||
|
||||
```json
|
||||
{
|
||||
"topic": "%s",
|
||||
"description": "%s",
|
||||
"hypothesis": "%s",
|
||||
"column_instructions": %s,
|
||||
"columns": [
|
||||
{
|
||||
"col_name": "...", # should be an "_"-separated string
|
||||
"description": "...",
|
||||
"data_type": "...", # should be executable using python's `eval` function. E.g., str, float, int, bool
|
||||
"data_range": {...}, # should be either {"min": ..., "max": ...} or {"values": [...]}
|
||||
"is_distractor": true/false, # boolean indicating whether this is a distractor that could cause confusion during data analysis
|
||||
"is_target": true/false # boolean indicating whether this is the target variable for the hypothesis; at least one column should be the target
|
||||
},
|
||||
...
|
||||
],
|
||||
"pandas_instructions": %s,
|
||||
"pandas_equation_for_hypothesis": {
|
||||
"target_col": "...",
|
||||
"target_col_type": "...",
|
||||
"target_col_range": {...},
|
||||
"independent_cols_in_pandas_expression": [], # list of column names that will be used to derive the target column
|
||||
"pandas_expression": "..." # expression to derive df[target_col] using df[ind_col1], df[ind_col2], etc.
|
||||
}
|
||||
}```
|
||||
|
||||
Give your answer as a new JSON with the "columns" and "pandas_equation_for_hypothesis" keys filled using the following format:
|
||||
```json
|
||||
{
|
||||
"columns": [...],
|
||||
"pandas_equation_for_hypothesis": {...}
|
||||
}
|
||||
```"""
|
||||
|
||||
PROMPT_DER = """\
|
||||
Given a dataset topic, description, a true hypothesis that can be determined from the data, \
|
||||
and a target column from the dataset, generate a hypothesis for the target column using new independent columns not present in the existing columns.
|
||||
|
||||
```json
|
||||
{
|
||||
"topic": "%s",
|
||||
"description": "%s",
|
||||
"hypothesis": "%s",
|
||||
"existing_columns": %s,
|
||||
"target_column": "%s",
|
||||
"new_to_target_instructions": %s,
|
||||
"new_to_target_hypothesis": "...", # describe a relationship between new columns that explains the target column
|
||||
"new_columns_for_target": [ # do not repeat any of the existing columns in the dataset
|
||||
{
|
||||
"col_name": "...", # should be an "_"-separated string
|
||||
"description": "...",
|
||||
"data_type": "...", # should be executable using python's `eval` function. E.g., str, float, int, bool
|
||||
"data_range": {...}, # should be either {"min": ..., "max": ...} or {"values": [...]}
|
||||
},
|
||||
...
|
||||
],
|
||||
"pandas_instructions": %s,
|
||||
"pandas_equation_for_new_to_target_hypothesis": {
|
||||
"target_col": "...",
|
||||
"target_col_type": "...",
|
||||
"target_col_range": {...},
|
||||
"independent_cols_in_pandas_expression": [], # list of column names from new_columns_for_target that will be used to derive target_col
|
||||
"pandas_expression": "..." # expression to derive df[target_col] using df[ind_col1], df[ind_col2], etc.
|
||||
}
|
||||
}```
|
||||
|
||||
Give your answer as a new JSON with the "new_to_target_hypothesis", "new_columns_for_target", and \
|
||||
"pandas_equation_for_new_to_target_hypothesis" keys filled using the following format:
|
||||
```json
|
||||
{
|
||||
"new_to_target_hypothesis": "...",
|
||||
"new_columns_for_target": [...],
|
||||
"pandas_equation_for_new_to_target_hypothesis": {...}
|
||||
}
|
||||
```"""
|
||||
@@ -0,0 +1,52 @@
|
||||
workflow_summary_markers = [
|
||||
'WORKFLOW SUMMARY',
|
||||
'WORKFLOW_SUMMARY',
|
||||
'WORKFLOW-SUMMARY',
|
||||
'Workflow Summary',
|
||||
]
|
||||
|
||||
final_answer_markers = [
|
||||
'FINAL ANSWER',
|
||||
'FINAL_ANSWER',
|
||||
'FINAL-ANSWER',
|
||||
'Final Answer',
|
||||
'Scientific Hypothesis',
|
||||
'Hypothesis',
|
||||
]
|
||||
|
||||
next_agent_markers = [
|
||||
'NEXT AGENT',
|
||||
'NEXT-AGENT',
|
||||
'NEXT_AGENT',
|
||||
'FEEDBACK',
|
||||
]
|
||||
|
||||
|
||||
def extract_between(content, start_markers, end_markers=None):
|
||||
for marker in start_markers:
|
||||
if marker in content:
|
||||
result = content.split(marker, 1)[1]
|
||||
if end_markers:
|
||||
for end_marker in end_markers:
|
||||
if end_marker in result:
|
||||
result = result.split(end_marker, 1)[0]
|
||||
return result
|
||||
return ''
|
||||
|
||||
|
||||
def extract_gen_hypo_from_logs(content: str):
|
||||
error = ''
|
||||
|
||||
gen_workflow = extract_between(
|
||||
content, workflow_summary_markers, final_answer_markers
|
||||
)
|
||||
|
||||
if not gen_workflow:
|
||||
error += 'No Workflow Summary found in the line. | '
|
||||
|
||||
gen_hypothesis = extract_between(content, final_answer_markers, next_agent_markers)
|
||||
|
||||
if not gen_hypothesis:
|
||||
error += 'No Final Answer in the line.'
|
||||
|
||||
return gen_hypothesis, gen_workflow, error
|
||||
@@ -0,0 +1,491 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
|
||||
import git
|
||||
import pandas as pd
|
||||
|
||||
from evaluation.discoverybench.eval_utils.eval_w_subhypo_gen import (
|
||||
run_eval_gold_vs_gen_NL_hypo_workflow,
|
||||
)
|
||||
from evaluation.discoverybench.eval_utils.response_parser import (
|
||||
extract_gen_hypo_from_logs,
|
||||
)
|
||||
from evaluation.utils.shared import (
|
||||
EvalMetadata,
|
||||
EvalOutput,
|
||||
codeact_user_response,
|
||||
make_metadata,
|
||||
prepare_dataset,
|
||||
reset_logger_for_multiprocessing,
|
||||
run_evaluation,
|
||||
)
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import (
|
||||
AgentConfig,
|
||||
AppConfig,
|
||||
SandboxConfig,
|
||||
get_llm_config_arg,
|
||||
parse_arguments,
|
||||
)
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.main import create_runtime, run_controller
|
||||
from openhands.events.action import AgentFinishAction, CmdRunAction, MessageAction
|
||||
from openhands.events.observation import CmdOutputObservation
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.utils.async_utils import call_async_from_sync
|
||||
|
||||
EVALUATION_LLM = 'gpt-4-1106-preview'
|
||||
|
||||
DATA_FILES = {}
|
||||
|
||||
LIBRARIES = [
|
||||
'pandas',
|
||||
'numpy',
|
||||
'scipy',
|
||||
'matplotlib',
|
||||
'seaborn',
|
||||
'scikit-learn',
|
||||
'statsmodels',
|
||||
]
|
||||
|
||||
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
|
||||
'CodeActAgent': codeact_user_response,
|
||||
}
|
||||
|
||||
AGENT_CLS_TO_INST_SUFFIX = {
|
||||
'CodeActAgent': 'When you think you have fixed the issue through code changes, please run the following command: <execute_bash> exit </execute_bash>.\n'
|
||||
}
|
||||
|
||||
|
||||
def get_config(
|
||||
metadata: EvalMetadata,
|
||||
) -> AppConfig:
|
||||
config = AppConfig(
|
||||
default_agent=metadata.agent_class,
|
||||
run_as_openhands=False,
|
||||
runtime='eventstream',
|
||||
max_iterations=metadata.max_iterations,
|
||||
sandbox=SandboxConfig(
|
||||
base_container_image='python:3.12-bookworm',
|
||||
enable_auto_lint=True,
|
||||
use_host_network=False,
|
||||
),
|
||||
# do not mount workspace
|
||||
workspace_base=None,
|
||||
workspace_mount_path=None,
|
||||
)
|
||||
config.set_llm_config(metadata.llm_config)
|
||||
agent_config = AgentConfig(
|
||||
function_calling=False,
|
||||
codeact_enable_jupyter=True,
|
||||
codeact_enable_browsing_delegate=True,
|
||||
)
|
||||
config.set_agent_config(agent_config)
|
||||
return config
|
||||
|
||||
|
||||
def get_dv_query_for_real(
|
||||
datasets, question, domain_knowledge=None, workflow_tags=None
|
||||
):
|
||||
"""
|
||||
Prepare a structured query for the agent to execute on the specified datasets.
|
||||
|
||||
This function constructs a query by compiling metadata from the provided datasets, along with any relevant domain knowledge and workflow tags.
|
||||
|
||||
Args:
|
||||
datasets: List of datasets
|
||||
question: Query to be answered
|
||||
domain_knowledge: Domain knowledge if any
|
||||
workflow_tags: Workflow tags if any
|
||||
|
||||
Returns:
|
||||
query_to_dv: Query to be run on the dataset
|
||||
dataset_meta: Metadata of the dataset
|
||||
"""
|
||||
|
||||
dataset_meta = ''
|
||||
for dataset_metadata in datasets:
|
||||
dataset_meta += 'Dataset name: ' + dataset_metadata['name']
|
||||
dataset_meta += 'Dataset description: ' + dataset_metadata['description']
|
||||
dataset_meta += '\nBrief description of columns: '
|
||||
for col in dataset_metadata['columns']['raw']:
|
||||
dataset_meta += col['name'] + ': ' + col['description'] + ', '
|
||||
|
||||
query_to_dv = dataset_meta
|
||||
|
||||
query_to_dv += f'\nQuery: {question}'
|
||||
|
||||
if domain_knowledge:
|
||||
query_to_dv += (
|
||||
'\nAdditionally, we provide some hints that might be useful to solve the task. Domain Knowledge: \n'
|
||||
+ domain_knowledge
|
||||
+ '.\n'
|
||||
)
|
||||
|
||||
if workflow_tags:
|
||||
query_to_dv += 'The meta tags are: ' + workflow_tags + '.\n'
|
||||
|
||||
query_to_dv += (
|
||||
'In the final answer, please write down a scientific hypothesis in '
|
||||
'natural language, derived from the provided dataset, clearly stating the '
|
||||
'context of hypothesis (if any), variables chosen (if any) and '
|
||||
'relationship between those variables (if any) including any statistical significance.'
|
||||
'Also generate a summary of the full workflow starting from data loading that led to the final answer as WORKFLOW SUMMARY:'
|
||||
)
|
||||
|
||||
# Run the NL query through datavoyager
|
||||
return query_to_dv, dataset_meta
|
||||
|
||||
|
||||
def initialize_runtime(runtime: Runtime, data_files: list[str]):
|
||||
"""
|
||||
Initialize the runtime for the agent.
|
||||
|
||||
This function is called before the runtime is used to run the agent.
|
||||
"""
|
||||
logger.info(f"{'-' * 50} BEGIN Runtime Initialization Fn {'-' * 50}")
|
||||
obs: CmdOutputObservation
|
||||
|
||||
action = CmdRunAction(command='mkdir -p /workspace')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
action = CmdRunAction(command='cd /workspace')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
for file in data_files:
|
||||
runtime.copy_to(
|
||||
file,
|
||||
'/workspace',
|
||||
)
|
||||
|
||||
for lib in LIBRARIES:
|
||||
action = CmdRunAction(command=f'pip install {lib}')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
logger.info(f"{'-' * 50} END Runtime Initialization Fn {'-' * 50}")
|
||||
|
||||
|
||||
def get_last_agent_finish_action(state: State) -> AgentFinishAction:
|
||||
for event in state.history.get_events(reverse=True):
|
||||
if isinstance(event, AgentFinishAction):
|
||||
return event
|
||||
return None
|
||||
|
||||
|
||||
def get_last_message_action(state: State) -> MessageAction:
|
||||
for event in state.history.get_events(reverse=True):
|
||||
if isinstance(event, MessageAction):
|
||||
return event
|
||||
return None
|
||||
|
||||
|
||||
def complete_runtime(state: State):
|
||||
last_agent_finish_action = get_last_agent_finish_action(state)
|
||||
last_agent_message_action = get_last_message_action(state)
|
||||
|
||||
if last_agent_finish_action is not None:
|
||||
final_message_1 = last_agent_finish_action.thought
|
||||
gen_hypo_1, gen_workflow_1, error_1 = extract_gen_hypo_from_logs(
|
||||
final_message_1
|
||||
)
|
||||
else:
|
||||
gen_hypo_1, gen_workflow_1, error_1 = '', '', ''
|
||||
|
||||
if last_agent_message_action is not None:
|
||||
final_message_2 = last_agent_message_action.content
|
||||
gen_hypo_2, gen_workflow_2, error_2 = extract_gen_hypo_from_logs(
|
||||
final_message_2
|
||||
)
|
||||
else:
|
||||
gen_hypo_2, gen_workflow_2, error_2 = '', '', ''
|
||||
|
||||
if gen_hypo_1 and gen_hypo_2:
|
||||
test_result = {
|
||||
'gen_hypo': last_agent_finish_action.thought
|
||||
if last_agent_finish_action
|
||||
else last_agent_message_action.content,
|
||||
'gen_workflow': '',
|
||||
'error': '',
|
||||
}
|
||||
return test_result
|
||||
|
||||
test_result = {
|
||||
'gen_hypo': gen_hypo_1 if gen_hypo_1 else gen_hypo_2,
|
||||
'gen_workflow': gen_workflow_1 if gen_workflow_1 else gen_workflow_2,
|
||||
'error': error_1 if error_1 else error_2,
|
||||
}
|
||||
|
||||
return test_result
|
||||
|
||||
|
||||
def process_instance(
|
||||
instance: pd.Series,
|
||||
metadata: EvalMetadata,
|
||||
reset_logger: bool = True,
|
||||
):
|
||||
"""
|
||||
Process and evaluate a single instance of the dataset.
|
||||
|
||||
This function executes the OpenHands agent
|
||||
for a specific instance of the dataset. It retrieves
|
||||
the agent's results and evaluates them against the gold
|
||||
hypothesis.
|
||||
|
||||
Args:
|
||||
instance: A single row of the dataset
|
||||
metadata: Metadata for the evaluation
|
||||
reset_logger: Whether to reset the logger
|
||||
|
||||
Returns:
|
||||
output: EvalOutput object
|
||||
"""
|
||||
|
||||
config = get_config(metadata)
|
||||
|
||||
# use a session id for concurrent evaluation
|
||||
sid = 'ID_' + str(instance.instance_id)
|
||||
|
||||
# Setup the logger properly, so you can run
|
||||
# multi-processing to parallelize the evaluation
|
||||
if reset_logger:
|
||||
log_dir = os.path.join(metadata.eval_output_dir, 'infer_logs')
|
||||
reset_logger_for_multiprocessing(logger, instance.instance_id, log_dir)
|
||||
else:
|
||||
logger.info(f'Starting evaluation for instance {instance.instance_id}.')
|
||||
|
||||
problem_statement, dataset_metadata = get_dv_query_for_real(
|
||||
datasets=instance.datasets,
|
||||
question=instance.query,
|
||||
domain_knowledge=instance.domain_knowledge,
|
||||
workflow_tags=instance.workflow_tags,
|
||||
)
|
||||
|
||||
# Prepare instruction
|
||||
instruction = (
|
||||
f'You are a discovery agent who can execute a python code only once to answer a query based on one or more datasets. The datasets will be present in the current directory.\n\n'
|
||||
'Environment has been set up for you to start working. You may assume all necessary tools and datasets are installed.\n\n'
|
||||
'# Problem Statement\n'
|
||||
f'{problem_statement}\n\n'
|
||||
)
|
||||
instruction += (
|
||||
'IMPORTANT: You should ONLY interact with the environment provided to you AND NEVER ASK FOR HUMAN HELP.\n'
|
||||
'You should NOT modify any existing test case files. If needed, you can add new test cases in a NEW file to reproduce the issue.\n'
|
||||
'You SHOULD INCLUDE PROPER INDENTATION in your edit commands.\n'
|
||||
)
|
||||
# NOTE: You can actually set slightly different instruction for different agents
|
||||
instruction += AGENT_CLS_TO_INST_SUFFIX[metadata.agent_class]
|
||||
|
||||
# Here's how you can run the agent (similar to the `main` function) and get the final task state
|
||||
runtime = create_runtime(config, sid=sid)
|
||||
call_async_from_sync(runtime.connect)
|
||||
initialize_runtime(runtime, instance.data_files)
|
||||
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content=instruction),
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
|
||||
metadata.agent_class
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if state is None:
|
||||
raise ValueError('State should not be None.')
|
||||
|
||||
metrics = state.metrics.get() if state.metrics else None
|
||||
test_result = complete_runtime(state)
|
||||
|
||||
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
|
||||
# for compatibility with the existing output format, we can remake the pairs here
|
||||
# remove when it becomes unnecessary
|
||||
histories = state.history.compatibility_for_eval_history_pairs()
|
||||
|
||||
# DiscoveryBench Evaluation
|
||||
eval_rec = run_eval_gold_vs_gen_NL_hypo_workflow(
|
||||
query=instance.query,
|
||||
gold_hypo=instance.gold_hypo,
|
||||
gold_workflow='',
|
||||
gen_hypo=test_result['gen_hypo'],
|
||||
gen_workflow='',
|
||||
dataset_meta=instance.dataset_metadata,
|
||||
llm_used=EVALUATION_LLM,
|
||||
dataset_type='real',
|
||||
)
|
||||
|
||||
test_result['eval_rec'] = eval_rec
|
||||
|
||||
output = EvalOutput(
|
||||
instance_id=str(instance.instance_id),
|
||||
instruction=instruction,
|
||||
metadata=metadata,
|
||||
history=histories,
|
||||
metrics=metrics,
|
||||
error=state.last_error if state and state.last_error else None,
|
||||
test_result=test_result,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def update_csv_name(name):
|
||||
name = name.replace('-', '_')
|
||||
|
||||
if 'meta_regression' in name:
|
||||
name = name.replace('meta_regression', 'meta-regression')
|
||||
if 'ML_enabled' in name:
|
||||
name = name.replace('ML_enabled', 'ML-enabled')
|
||||
|
||||
return name
|
||||
|
||||
|
||||
def list_csv_files(list_of_datasets):
|
||||
res = []
|
||||
for ele in list_of_datasets:
|
||||
for key, value in ele.items():
|
||||
if key == 'name':
|
||||
csv_file_name = update_csv_name(value)
|
||||
res.append(DATA_FILES[csv_file_name])
|
||||
return res
|
||||
|
||||
|
||||
def create_dataset(repo_location: str, split: str = 'test'):
|
||||
"""
|
||||
Create a dataset from the discoverybench repository
|
||||
by walking through the repository and extracting metadata
|
||||
from the metadata_{}.json files
|
||||
|
||||
Args:
|
||||
repo_location: Location of the repository
|
||||
split: Split of the dataset to use
|
||||
|
||||
Returns:
|
||||
df: DataFrame containing the dataset instances
|
||||
"""
|
||||
|
||||
data_dict = {}
|
||||
|
||||
data_location = os.path.join(repo_location, 'discoverybench', 'real', split)
|
||||
answer_key_location = os.path.join(repo_location, 'eval', 'answer_key_real.csv')
|
||||
|
||||
idx = 0
|
||||
|
||||
for root, dirs, files in os.walk(data_location):
|
||||
for file in files:
|
||||
if file.endswith('.json'):
|
||||
if 'metadata' in file:
|
||||
metadata = json.load(open(os.path.join(root, file)))
|
||||
|
||||
dataset = root.split('/')[-1]
|
||||
metadata_id = file.split('_')[-1].split('.')[0]
|
||||
domain = metadata.get('domain', '')
|
||||
domain_knowledge = metadata.get('domain_knowledge', '')
|
||||
workflow_tags = metadata.get('workflow_tags', '')
|
||||
datasets = metadata.get('datasets', [])
|
||||
queries = metadata.get('queries', [])
|
||||
gold_workflow = metadata.get('workflow')
|
||||
|
||||
# loop through queries list to get queries
|
||||
# and each query has qid; add that to dictionary
|
||||
for query in queries[0]:
|
||||
qid = query.get('qid', '')
|
||||
|
||||
data = {
|
||||
'dataset': dataset,
|
||||
'metadata_id': metadata_id,
|
||||
'qid': qid,
|
||||
'domain': domain,
|
||||
'domain_knowledge': domain_knowledge,
|
||||
'workflow_tags': workflow_tags,
|
||||
'datasets': datasets,
|
||||
'question_type': query['question_type'],
|
||||
'query': query['question'],
|
||||
'gold_workflow': gold_workflow,
|
||||
'dataset_metadata': metadata,
|
||||
}
|
||||
|
||||
data_dict[idx] = data
|
||||
idx += 1
|
||||
|
||||
if file.endswith('.csv'):
|
||||
DATA_FILES[file] = os.path.join(root, file)
|
||||
if file.endswith('.txt'):
|
||||
DATA_FILES[file] = os.path.join(root, file)
|
||||
|
||||
df = pd.DataFrame.from_dict(data_dict, orient='index')
|
||||
|
||||
df['instance_id'] = df.index
|
||||
|
||||
df['data_files'] = df['datasets'].apply(lambda x: list_csv_files(x))
|
||||
|
||||
answer_key = pd.read_csv(answer_key_location)
|
||||
|
||||
answer_key = answer_key.rename(
|
||||
columns={
|
||||
'metadataid': 'metadata_id',
|
||||
'query_id': 'qid',
|
||||
'gold_hypothesis': 'gold_hypothesis',
|
||||
}
|
||||
)
|
||||
|
||||
df['qid'] = df['qid'].astype(int)
|
||||
df['metadata_id'] = df['metadata_id'].astype(int)
|
||||
|
||||
answer_key['qid'] = answer_key['qid'].astype(int)
|
||||
answer_key['metadata_id'] = answer_key['metadata_id'].astype(int)
|
||||
|
||||
df = pd.merge(df, answer_key, on=['dataset', 'metadata_id', 'qid'], how='left')
|
||||
|
||||
return df
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_arguments()
|
||||
|
||||
# clone git repositor for csv files
|
||||
repo_url = 'https://github.com/allenai/discoverybench.git'
|
||||
repo_location = 'git-discoverybench-allenai'
|
||||
|
||||
try:
|
||||
git.Repo.clone_from(repo_url, repo_location)
|
||||
except git.exc.GitCommandError:
|
||||
print('Repository already exists')
|
||||
|
||||
dataset = create_dataset(repo_location)
|
||||
|
||||
# check if there is any empty csv_file
|
||||
if dataset['data_files'].isnull().any():
|
||||
raise ValueError('Some csv files are missing.')
|
||||
|
||||
llm_config = None
|
||||
if args.llm_config:
|
||||
llm_config = get_llm_config_arg(args.llm_config)
|
||||
if llm_config is None:
|
||||
raise ValueError(f'Could not find LLM config: --llm_config {args.llm_config}')
|
||||
|
||||
metadata = make_metadata(
|
||||
llm_config,
|
||||
'discoverybench-python',
|
||||
args.agent_cls,
|
||||
args.max_iterations,
|
||||
args.eval_note,
|
||||
args.eval_output_dir,
|
||||
)
|
||||
output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
|
||||
instances = prepare_dataset(dataset, output_file, args.eval_n_limit)
|
||||
|
||||
run_evaluation(
|
||||
instances,
|
||||
metadata,
|
||||
output_file,
|
||||
args.eval_num_workers,
|
||||
process_instance,
|
||||
)
|
||||
+46
@@ -0,0 +1,46 @@
|
||||
#!/bin/bash
|
||||
set -eo pipefail
|
||||
|
||||
source "evaluation/utils/version_control.sh"
|
||||
|
||||
MODEL_CONFIG=$1
|
||||
COMMIT_HASH=$2
|
||||
AGENT=$3
|
||||
EVAL_LIMIT=$4
|
||||
NUM_WORKERS=$5
|
||||
|
||||
if [ -z "$NUM_WORKERS" ]; then
|
||||
NUM_WORKERS=1
|
||||
echo "Number of workers not specified, use default $NUM_WORKERS"
|
||||
fi
|
||||
|
||||
# ################################################################################
|
||||
|
||||
checkout_eval_branch
|
||||
|
||||
if [ -z "$AGENT" ]; then
|
||||
echo "Agent not specified, use default CodeActAgent"
|
||||
AGENT="CodeActAgent"
|
||||
fi
|
||||
|
||||
get_agent_version
|
||||
|
||||
echo "AGENT: $AGENT"
|
||||
echo "AGENT_VERSION: $AGENT_VERSION"
|
||||
echo "MODEL_CONFIG: $MODEL_CONFIG"
|
||||
|
||||
COMMAND="poetry run python evaluation/discoverybench/run_infer.py \
|
||||
--agent-cls $AGENT \
|
||||
--llm-config $MODEL_CONFIG \
|
||||
--max-iterations 10 \
|
||||
--max-chars 10000000 \
|
||||
--eval-num-workers $NUM_WORKERS \
|
||||
--eval-note $AGENT_VERSION"
|
||||
|
||||
if [ -n "$EVAL_LIMIT" ]; then
|
||||
echo "EVAL_LIMIT: $EVAL_LIMIT"
|
||||
COMMAND="$COMMAND --eval-n-limit $EVAL_LIMIT"
|
||||
fi
|
||||
|
||||
# Run the command
|
||||
eval $COMMAND
|
||||
@@ -13,6 +13,7 @@ from evaluation.utils.shared import (
|
||||
prepare_dataset,
|
||||
reset_logger_for_multiprocessing,
|
||||
run_evaluation,
|
||||
update_llm_config_for_completions_logging,
|
||||
)
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import (
|
||||
@@ -55,18 +56,14 @@ def get_config(
|
||||
workspace_base=None,
|
||||
workspace_mount_path=None,
|
||||
)
|
||||
if metadata.llm_config.log_completions:
|
||||
metadata.llm_config.log_completions_folder = os.path.join(
|
||||
metadata.eval_output_dir, 'llm_completions', instance_id
|
||||
config.set_llm_config(
|
||||
update_llm_config_for_completions_logging(
|
||||
metadata.llm_config, metadata.eval_output_dir, instance_id
|
||||
)
|
||||
logger.info(
|
||||
f'Logging LLM completions for instance {instance_id} to '
|
||||
f'{metadata.llm_config.log_completions_folder}'
|
||||
)
|
||||
config.set_llm_config(metadata.llm_config)
|
||||
)
|
||||
agent_config = AgentConfig(
|
||||
codeact_enable_jupyter=True,
|
||||
codeact_enable_browsing_delegate=True,
|
||||
codeact_enable_browsing=True,
|
||||
codeact_enable_llm_editor=False,
|
||||
)
|
||||
config.set_agent_config(agent_config)
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
from evaluation.integration_tests.tests.base import BaseIntegrationTest, TestResult
|
||||
from openhands.events.action import AgentFinishAction, MessageAction
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation import AgentDelegateObservation
|
||||
from openhands.runtime.base import Runtime
|
||||
|
||||
|
||||
class Test(BaseIntegrationTest):
|
||||
INSTRUCTION = 'Look at https://github.com/All-Hands-AI/OpenHands/pull/8, and tell me what is happening there and what did @asadm suggest.'
|
||||
|
||||
@classmethod
|
||||
def initialize_runtime(cls, runtime: Runtime) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def verify_result(cls, runtime: Runtime, histories: list[Event]) -> TestResult:
|
||||
# check if the "The answer is OpenHands is all you need!" is in any message
|
||||
message_actions = [
|
||||
event
|
||||
for event in histories
|
||||
if isinstance(
|
||||
event, (MessageAction, AgentFinishAction, AgentDelegateObservation)
|
||||
)
|
||||
]
|
||||
for event in message_actions:
|
||||
if isinstance(event, AgentDelegateObservation):
|
||||
content = event.content
|
||||
elif isinstance(event, AgentFinishAction):
|
||||
content = event.outputs.get('content', '')
|
||||
elif isinstance(event, MessageAction):
|
||||
content = event.content
|
||||
else:
|
||||
raise ValueError(f'Unknown event type: {type(event)}')
|
||||
|
||||
if (
|
||||
'non-commercial' in content
|
||||
or 'MIT' in content
|
||||
or 'Apache 2.0' in content
|
||||
):
|
||||
return TestResult(success=True)
|
||||
return TestResult(
|
||||
success=False,
|
||||
reason=f'The answer is not found in any message. Total messages: {len(message_actions)}. Messages: {message_actions}',
|
||||
)
|
||||
@@ -10,10 +10,12 @@ import pandas as pd
|
||||
from evaluation.utils.shared import (
|
||||
EvalMetadata,
|
||||
EvalOutput,
|
||||
codeact_user_response,
|
||||
make_metadata,
|
||||
prepare_dataset,
|
||||
reset_logger_for_multiprocessing,
|
||||
run_evaluation,
|
||||
update_llm_config_for_completions_logging,
|
||||
)
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import (
|
||||
@@ -29,7 +31,10 @@ from openhands.events.action import (
|
||||
CmdRunAction,
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.events.observation import CmdOutputObservation
|
||||
from openhands.events.observation import (
|
||||
BrowserOutputObservation,
|
||||
CmdOutputObservation,
|
||||
)
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.runtime.browser.browser_env import (
|
||||
BROWSER_EVAL_GET_GOAL_ACTION,
|
||||
@@ -37,7 +42,11 @@ from openhands.runtime.browser.browser_env import (
|
||||
)
|
||||
from openhands.utils.async_utils import call_async_from_sync
|
||||
|
||||
SUPPORTED_AGENT_CLS = {'BrowsingAgent'}
|
||||
SUPPORTED_AGENT_CLS = {'BrowsingAgent', 'CodeActAgent'}
|
||||
|
||||
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
|
||||
'CodeActAgent': codeact_user_response,
|
||||
}
|
||||
|
||||
|
||||
def get_config(
|
||||
@@ -47,25 +56,32 @@ def get_config(
|
||||
config = AppConfig(
|
||||
default_agent=metadata.agent_class,
|
||||
run_as_openhands=False,
|
||||
runtime='eventstream',
|
||||
runtime=os.environ.get('RUNTIME', 'eventstream'),
|
||||
max_iterations=metadata.max_iterations,
|
||||
sandbox=SandboxConfig(
|
||||
base_container_image='xingyaoww/od-eval-miniwob:v1.0',
|
||||
enable_auto_lint=True,
|
||||
use_host_network=False,
|
||||
browsergym_eval_env=env_id,
|
||||
api_key=os.environ.get('ALLHANDS_API_KEY', None),
|
||||
remote_runtime_api_url=os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL'),
|
||||
keep_remote_runtime_alive=False,
|
||||
),
|
||||
# do not mount workspace
|
||||
workspace_base=None,
|
||||
workspace_mount_path=None,
|
||||
)
|
||||
config.set_llm_config(metadata.llm_config)
|
||||
config.set_llm_config(
|
||||
update_llm_config_for_completions_logging(
|
||||
metadata.llm_config, metadata.eval_output_dir, env_id
|
||||
)
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
def initialize_runtime(
|
||||
runtime: Runtime,
|
||||
) -> str:
|
||||
) -> tuple[str, BrowserOutputObservation]:
|
||||
"""Initialize the runtime for the agent.
|
||||
|
||||
This function is called before the runtime is used to run the agent.
|
||||
@@ -85,8 +101,14 @@ def initialize_runtime(
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
goal = obs.content
|
||||
|
||||
# Run noop to get the initial browser observation (e.g., the page URL & content)
|
||||
action = BrowseInteractiveAction(browser_actions='noop(1000)')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
|
||||
logger.info(f"{'-' * 50} END Runtime Initialization Fn {'-' * 50}")
|
||||
return goal
|
||||
return goal, obs
|
||||
|
||||
|
||||
def complete_runtime(
|
||||
@@ -117,7 +139,7 @@ def process_instance(
|
||||
metadata: EvalMetadata,
|
||||
reset_logger: bool = True,
|
||||
) -> EvalOutput:
|
||||
env_id = instance.id
|
||||
env_id = instance.instance_id
|
||||
config = get_config(metadata, env_id)
|
||||
|
||||
# Setup the logger properly, so you can run multi-processing to parallelize the evaluation
|
||||
@@ -129,7 +151,12 @@ def process_instance(
|
||||
|
||||
runtime = create_runtime(config)
|
||||
call_async_from_sync(runtime.connect)
|
||||
task_str = initialize_runtime(runtime)
|
||||
task_str, obs = initialize_runtime(runtime)
|
||||
|
||||
task_str += (
|
||||
f'\nInitial browser state (output of `noop(1000)`):\n{obs.get_agent_obs_text()}'
|
||||
)
|
||||
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
@@ -137,6 +164,9 @@ def process_instance(
|
||||
content=task_str
|
||||
), # take output from initialize_runtime
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
|
||||
metadata.agent_class
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
@@ -159,7 +189,7 @@ def process_instance(
|
||||
|
||||
return_val = complete_runtime(runtime)
|
||||
logger.info(f'Return value from complete_runtime: {return_val}')
|
||||
reward = max(return_val['rewards'])
|
||||
reward = max(return_val['rewards'], default=0)
|
||||
|
||||
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
|
||||
# for compatibility with the existing output format, we can remake the pairs here
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
FROM python:3.11-bookworm
|
||||
|
||||
|
||||
# For OpenHands agents to explore the dataset directories, please download the full benchmark [here](https://buckeyemailosu-my.sharepoint.com/:u:/g/personal/chen_8336_buckeyemail_osu_edu/EQuA6uJ3CtRHvRfZ2GiN1tYBRVJE4DSUD10MW61fr7HuSQ?e=sCBegG) and unzip it with password `scienceagentbench`.
|
||||
# **Please DO NOT redistribute the unzipped data files online.**
|
||||
# It will download a benchmark.zip file to the current directory.
|
||||
# unzip it and put the benchmark folder under evaluation/scienceagentbench/
|
||||
|
||||
RUN mkdir -p /benchmark
|
||||
COPY benchmark /benchmark
|
||||
|
||||
RUN mkdir -p /workspace
|
||||
WORKDIR /workspace
|
||||
|
||||
# pushd evaluation/scienceagentbench
|
||||
# docker build -t xingyaoww/openhands-eval-scienceagentbench .
|
||||
# popd
|
||||
@@ -0,0 +1,25 @@
|
||||
FROM mambaorg/micromamba:debian12
|
||||
|
||||
USER root
|
||||
# For https://github.com/OSU-NLP-Group/ScienceAgentBench/tree/main?tab=readme-ov-file#code-generation-with-agents
|
||||
|
||||
RUN micromamba create -n sci-agent-eval python=3.10 pip setuptools wheel
|
||||
RUN micromamba run -n sci-agent-eval pip install pip-tools
|
||||
|
||||
RUN mkdir -p /workspace
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN apt-get update && apt-get install -y git
|
||||
|
||||
RUN git clone https://github.com/OSU-NLP-Group/ScienceAgentBench.git /workspace/
|
||||
RUN git checkout 4eddc7db6449a5ade3e37285747c8b208cd54ce7
|
||||
|
||||
RUN micromamba create -n sci-agent python=3.10 pip setuptools wheel
|
||||
RUN micromamba run -n sci-agent pip install -r requirements.txt
|
||||
|
||||
# Replace all occurence of conda with micromamba under the /workspace
|
||||
RUN find ./ -type f -exec sed -i 's/conda/micromamba/g' {} \;
|
||||
|
||||
# pushd evaluation/scienceagentbench
|
||||
# docker build -t xingyaoww/openhands-eval-scienceagentbench-evaluator -f Dockerfile.evaluator .
|
||||
# popd
|
||||
@@ -0,0 +1,54 @@
|
||||
# ScienceAgentBench Evaluation with OpenHands
|
||||
|
||||
This folder contains the evaluation harness for [ScienceAgentBench](https://osu-nlp-group.github.io/ScienceAgentBench/) (paper: https://arxiv.org/abs/2410.05080).
|
||||
|
||||
## Setup Environment and LLM Configuration
|
||||
|
||||
Please follow instruction [here](../README.md#setup) to setup your local development environment and LLM.
|
||||
|
||||
## Setup ScienceAgentBench
|
||||
|
||||
To prevent benchmark data contamination, we only provide the annotation sheet on [Huggingface](https://huggingface.co/datasets/osunlp/ScienceAgentBench), which includes all necessary *inputs* to run an agent.
|
||||
|
||||
## Run Inference on ScienceAgentBench
|
||||
|
||||
```bash
|
||||
./evaluation/scienceagentbench/scripts/run_infer.sh [model_config] [git-version] [use_knowledge] [agent] [eval_limit] [max_iter] [num_workers] [dataset] [dataset_split]
|
||||
|
||||
# Example
|
||||
./evaluation/scienceagentbench/scripts/run_infer.sh llm.eval_gpt4o 0.9.3
|
||||
```
|
||||
|
||||
where `model_config` is mandatory, and the rest are optional.
|
||||
|
||||
- `model_config`, e.g. `eval_gpt4_1106_preview`, is the config group name for your
|
||||
LLM settings, as defined in your `config.toml`.
|
||||
- `git-version`, e.g. `HEAD`, is the git commit hash of the OpenHands version you would
|
||||
like to evaluate. It could also be a release tag like `0.6.2`.
|
||||
- `use_knowledge`, e.g. `true`, specifies whether allowing the agent to use expert-provided knowledge as additional input or not. By default, it is set to `false`.
|
||||
- `agent`, e.g. `CodeActAgent`, is the name of the agent for benchmarks, defaulting
|
||||
to `CodeActAgent`.
|
||||
- `eval_limit`, e.g. `10`, limits the evaluation to the first `eval_limit` instances. By
|
||||
default, the script evaluates the entire SWE-bench_Lite test set (300 issues). Note:
|
||||
in order to use `eval_limit`, you must also set `agent`.
|
||||
- `max_iter`, e.g. `20`, is the maximum number of iterations for the agent to run. By
|
||||
default, it is set to 30.
|
||||
- `num_workers`, e.g. `3`, is the number of parallel workers to run the evaluation. By
|
||||
default, it is set to 1.
|
||||
|
||||
## Evaluate Generated Programs
|
||||
|
||||
### Extract Necessary Information from OpenHands Log
|
||||
|
||||
After the inference is completed, you may use the following command to extract necessary information from the output log for evaluation:
|
||||
|
||||
```bash
|
||||
python post_proc.py [log_fname]
|
||||
```
|
||||
- `log_fname`, e.g. `evaluation/.../output.jsonl`, is the automatically saved trajectory log of an OpenHands agent.
|
||||
|
||||
Output will be write to e.g. `evaluation/.../output.converted.jsonl`
|
||||
|
||||
### Run evaluation
|
||||
|
||||
Please follow the steps [here](https://github.com/OSU-NLP-Group/ScienceAgentBench/tree/main?tab=readme-ov-file#evaluation-of-generated-code) to evaluate the generated programs.
|
||||
@@ -0,0 +1,30 @@
|
||||
import json
|
||||
from argparse import ArgumentParser
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument(
|
||||
'log_fname',
|
||||
type=str,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
fname = args.log_fname
|
||||
out_fname = args.log_fname.replace('.jsonl', '.converted.jsonl')
|
||||
|
||||
log = [json.loads(line) for line in open(fname)]
|
||||
|
||||
simple_log = [
|
||||
json.dumps(
|
||||
{
|
||||
'instance_id': ex['instance_id'],
|
||||
'instruction': ex['instruction'],
|
||||
'test_result': ex['test_result'],
|
||||
'cost': ex['metrics']['accumulated_cost'],
|
||||
}
|
||||
)
|
||||
for ex in log
|
||||
]
|
||||
|
||||
with open(out_fname, 'w+', encoding='utf-8') as f:
|
||||
f.write('\n'.join(simple_log))
|
||||
@@ -0,0 +1,291 @@
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import pandas as pd
|
||||
from datasets import load_dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from evaluation.utils.shared import (
|
||||
EvalMetadata,
|
||||
EvalOutput,
|
||||
codeact_user_response,
|
||||
make_metadata,
|
||||
prepare_dataset,
|
||||
reset_logger_for_multiprocessing,
|
||||
run_evaluation,
|
||||
update_llm_config_for_completions_logging,
|
||||
)
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import (
|
||||
AppConfig,
|
||||
SandboxConfig,
|
||||
get_llm_config_arg,
|
||||
get_parser,
|
||||
)
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.main import create_runtime, run_controller
|
||||
from openhands.events.action import CmdRunAction, MessageAction
|
||||
from openhands.events.observation import CmdOutputObservation
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.utils.async_utils import call_async_from_sync
|
||||
|
||||
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
|
||||
'CodeActAgent': codeact_user_response,
|
||||
}
|
||||
|
||||
LOCAL_DATASET_PATH = os.path.join(os.path.dirname(__file__), 'benchmark')
|
||||
|
||||
|
||||
def format_task_dict(example, use_knowledge):
|
||||
task = {
|
||||
'instance_id': example['instance_id'],
|
||||
'task_inst': example['task_inst'],
|
||||
'dataset_path': '/benchmark/datasets/'
|
||||
+ example['dataset_folder_tree'].split('\n')[0][4:],
|
||||
'dataset_folder_tree': example['dataset_folder_tree'],
|
||||
'dataset_preview': example['dataset_preview'],
|
||||
'pred_program_name': 'pred_' + example['gold_program_name'],
|
||||
}
|
||||
|
||||
if use_knowledge:
|
||||
task['task_inst'] += '\n' + str(example['domain_knowledge'])
|
||||
|
||||
return task
|
||||
|
||||
|
||||
def get_config(
|
||||
metadata: EvalMetadata,
|
||||
instance_id: str,
|
||||
) -> AppConfig:
|
||||
config = AppConfig(
|
||||
default_agent=metadata.agent_class,
|
||||
run_as_openhands=False,
|
||||
runtime=os.environ.get('RUNTIME', 'eventstream'),
|
||||
max_budget_per_task=4,
|
||||
max_iterations=metadata.max_iterations,
|
||||
sandbox=SandboxConfig(
|
||||
base_container_image='docker.io/xingyaoww/openhands-eval-scienceagentbench',
|
||||
enable_auto_lint=True,
|
||||
use_host_network=False,
|
||||
timeout=300,
|
||||
api_key=os.environ.get('ALLHANDS_API_KEY', None),
|
||||
remote_runtime_api_url=os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL'),
|
||||
keep_remote_runtime_alive=False,
|
||||
),
|
||||
# do not mount workspace
|
||||
workspace_base=None,
|
||||
workspace_mount_path=None,
|
||||
)
|
||||
config.set_llm_config(
|
||||
update_llm_config_for_completions_logging(
|
||||
metadata.llm_config,
|
||||
metadata.eval_output_dir,
|
||||
instance_id,
|
||||
)
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
def initialize_runtime(
|
||||
runtime: Runtime,
|
||||
instance: pd.Series, # this argument is not required
|
||||
):
|
||||
"""Initialize the runtime for the agent.
|
||||
|
||||
This function is called before the runtime is used to run the agent.
|
||||
"""
|
||||
logger.info(f"{'-' * 50} BEGIN Runtime Initialization Fn {'-' * 50}")
|
||||
obs: CmdOutputObservation
|
||||
|
||||
# Set up workspace directories
|
||||
action = CmdRunAction(command='mkdir -p /workspace/pred_programs')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
action = CmdRunAction(command='mkdir -p /workspace/pred_results')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
dataset_name = instance['dataset_folder_tree'].split('\n')[0][4:].rstrip('/')
|
||||
|
||||
# Copy the dataset to the workspace
|
||||
dataset_dir = os.path.join(
|
||||
LOCAL_DATASET_PATH,
|
||||
'datasets',
|
||||
dataset_name,
|
||||
)
|
||||
runtime.copy_to(dataset_dir, '/workspace/benchmark/datasets', recursive=True)
|
||||
|
||||
# Check the dataset exists
|
||||
action = CmdRunAction(
|
||||
command='cd /workspace/benchmark/datasets && ls',
|
||||
keep_prompt=False,
|
||||
)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
assert dataset_name in obs.content
|
||||
|
||||
logger.info(f"{'-' * 50} END Runtime Initialization Fn {'-' * 50}")
|
||||
|
||||
|
||||
def complete_runtime(
|
||||
runtime: Runtime,
|
||||
instance: pd.Series,
|
||||
) -> dict[str, Any]:
|
||||
"""Complete the runtime for the agent.
|
||||
|
||||
This function is called before the runtime is used to run the agent.
|
||||
If you need to do something in the sandbox to get the correctness metric after
|
||||
the agent has run, modify this function.
|
||||
"""
|
||||
logger.info(f"{'-' * 50} BEGIN Runtime Completion Fn {'-' * 50}")
|
||||
obs: CmdOutputObservation
|
||||
|
||||
test_result = {}
|
||||
|
||||
action = CmdRunAction(command='cd /workspace')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
|
||||
assert obs.exit_code == 0
|
||||
|
||||
action = CmdRunAction(
|
||||
command=f'cat pred_programs/{instance.pred_program_name}',
|
||||
keep_prompt=False,
|
||||
)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
|
||||
if obs.exit_code == 0:
|
||||
test_result = {'program': obs.content}
|
||||
else:
|
||||
test_result = {'program': 'ERROR'}
|
||||
|
||||
logger.info(f"{'-' * 50} END Runtime Completion Fn {'-' * 50}")
|
||||
return test_result
|
||||
|
||||
|
||||
def process_instance(
|
||||
instance: pd.Series,
|
||||
metadata: EvalMetadata,
|
||||
reset_logger: bool = True,
|
||||
) -> EvalOutput:
|
||||
instance_id = instance.instance_id.replace('/', '__')
|
||||
config = get_config(metadata, instance_id)
|
||||
|
||||
# Set up the logger properly, so you can run multi-processing to parallelize the evaluation
|
||||
if reset_logger:
|
||||
log_dir = os.path.join(metadata.eval_output_dir, 'infer_logs')
|
||||
reset_logger_for_multiprocessing(logger, instance_id, log_dir)
|
||||
else:
|
||||
logger.info(f'Starting evaluation for instance {instance_id}.')
|
||||
|
||||
instruction = f"""You are an expert Python programming assistant that helps scientist users to write high-quality code to solve their tasks.
|
||||
Given a user request, you are expected to write a complete program that accomplishes the requested task and save any outputs to `/workspace/pred_results/` in the correct format.
|
||||
|
||||
Here's the user request you need to work on:
|
||||
{instance.task_inst}
|
||||
|
||||
You can access the dataset at `{instance.dataset_path}`. Here is the directory structure of the dataset:
|
||||
```
|
||||
{instance.dataset_folder_tree}
|
||||
```
|
||||
Here are some helpful previews for the dataset file(s):
|
||||
{instance.dataset_preview}
|
||||
|
||||
Please save your program as `/workspace/pred_programs/{instance.pred_program_name}`.
|
||||
Then, please run the program to check and fix any errors.
|
||||
Please do NOT run the program in the background.
|
||||
If the program uses some packages that are incompatible, please figure out alternative implementations and do NOT restart the environment.
|
||||
|
||||
"""
|
||||
|
||||
runtime = create_runtime(config)
|
||||
call_async_from_sync(runtime.connect)
|
||||
initialize_runtime(runtime, instance)
|
||||
|
||||
# Here's how you can run the agent (similar to the `main` function) and get the final task state
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content=instruction),
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
|
||||
metadata.agent_class
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# ======= Attempt to evaluate the agent's edits =======
|
||||
test_result = complete_runtime(runtime, instance)
|
||||
|
||||
# If you are working on some simpler benchmark that only evaluates the final model output (e.g., in a MessageAction)
|
||||
# You can simply get the LAST `MessageAction` from the returned `state.history` and parse it for evaluation.
|
||||
if state is None:
|
||||
raise ValueError('State should not be None.')
|
||||
metrics = state.metrics.get() if state.metrics else None
|
||||
|
||||
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
|
||||
# for compatibility with the existing output format, we can remake the pairs here
|
||||
# remove when it becomes unnecessary
|
||||
histories = state.history.compatibility_for_eval_history_pairs()
|
||||
|
||||
# Save the output
|
||||
output = EvalOutput(
|
||||
instance_id=instance.instance_id,
|
||||
instruction=instruction,
|
||||
metadata=metadata,
|
||||
history=histories,
|
||||
metrics=metrics,
|
||||
error=state.last_error if state and state.last_error else None,
|
||||
test_result=test_result,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = get_parser()
|
||||
parser.add_argument(
|
||||
'--use_knowledge',
|
||||
type=str,
|
||||
default='false',
|
||||
choices=['true', 'false'],
|
||||
help='use expert-provided knowledge or not',
|
||||
)
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
sab_dataset = load_dataset('osunlp/ScienceAgentBench', split='validation')
|
||||
|
||||
dataset_processed = []
|
||||
for example in tqdm(sab_dataset):
|
||||
dataset_processed.append(
|
||||
format_task_dict(example, args.use_knowledge == 'true')
|
||||
)
|
||||
|
||||
dataset = pd.DataFrame(dataset_processed)
|
||||
|
||||
llm_config = None
|
||||
if args.llm_config:
|
||||
llm_config = get_llm_config_arg(args.llm_config)
|
||||
if llm_config is None:
|
||||
raise ValueError(f'Could not find LLM config: --llm_config {args.llm_config}')
|
||||
|
||||
metadata = make_metadata(
|
||||
llm_config,
|
||||
'ScienceAgentBench',
|
||||
args.agent_cls,
|
||||
args.max_iterations,
|
||||
args.eval_note,
|
||||
args.eval_output_dir,
|
||||
)
|
||||
output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
|
||||
dataset['instance_id'] = dataset['instance_id'].apply(str)
|
||||
instances = prepare_dataset(dataset, output_file, args.eval_n_limit)
|
||||
|
||||
run_evaluation(
|
||||
instances, metadata, output_file, args.eval_num_workers, process_instance
|
||||
)
|
||||
+49
@@ -0,0 +1,49 @@
|
||||
#!/bin/bash
|
||||
set -eo pipefail
|
||||
|
||||
source "evaluation/utils/version_control.sh"
|
||||
|
||||
MODEL_CONFIG=$1
|
||||
COMMIT_HASH=$2
|
||||
USE_KNOWLEDGE=$3
|
||||
AGENT=$4
|
||||
EVAL_LIMIT=$5
|
||||
NUM_WORKERS=$6
|
||||
|
||||
if [ -z "$NUM_WORKERS" ]; then
|
||||
NUM_WORKERS=1
|
||||
echo "Number of workers not specified, use default $NUM_WORKERS"
|
||||
fi
|
||||
checkout_eval_branch
|
||||
|
||||
if [ -z "$AGENT" ]; then
|
||||
echo "Agent not specified, use default CodeActAgent"
|
||||
AGENT="CodeActAgent"
|
||||
fi
|
||||
|
||||
if [ -z "$USE_KNOWLEDGE" ]; then
|
||||
echo "Use knowledge not specified, use default False"
|
||||
USE_KNOWLEDGE=false
|
||||
fi
|
||||
|
||||
get_agent_version
|
||||
|
||||
echo "AGENT: $AGENT"
|
||||
echo "AGENT_VERSION: $AGENT_VERSION"
|
||||
echo "MODEL_CONFIG: $MODEL_CONFIG"
|
||||
|
||||
COMMAND="poetry run python evaluation/scienceagentbench/run_infer.py \
|
||||
--agent-cls $AGENT \
|
||||
--llm-config $MODEL_CONFIG \
|
||||
--use_knowledge $USE_KNOWLEDGE \
|
||||
--max-iterations 30 \
|
||||
--eval-num-workers $NUM_WORKERS \
|
||||
--eval-note $AGENT_VERSION" \
|
||||
|
||||
if [ -n "$EVAL_LIMIT" ]; then
|
||||
echo "EVAL_LIMIT: $EVAL_LIMIT"
|
||||
COMMAND="$COMMAND --eval-n-limit $EVAL_LIMIT"
|
||||
fi
|
||||
|
||||
# Run the command
|
||||
eval $COMMAND
|
||||
@@ -239,7 +239,7 @@ def process_instance(
|
||||
# Create a directory structure that matches the expected format
|
||||
# NOTE: this is a hack to make the eval report format consistent
|
||||
# with the original SWE-Bench eval script
|
||||
log_dir = os.path.join(temp_dir, 'logs', instance_id)
|
||||
log_dir = os.path.join(temp_dir, 'logs', instance_id.lower())
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
test_output_path = os.path.join(log_dir, 'test_output.txt')
|
||||
with open(test_output_path, 'w') as f:
|
||||
|
||||
@@ -20,6 +20,7 @@ from evaluation.utils.shared import (
|
||||
prepare_dataset,
|
||||
reset_logger_for_multiprocessing,
|
||||
run_evaluation,
|
||||
update_llm_config_for_completions_logging,
|
||||
)
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import (
|
||||
@@ -40,6 +41,7 @@ from openhands.utils.async_utils import call_async_from_sync
|
||||
|
||||
USE_HINT_TEXT = os.environ.get('USE_HINT_TEXT', 'false').lower() == 'true'
|
||||
USE_INSTANCE_IMAGE = os.environ.get('USE_INSTANCE_IMAGE', 'false').lower() == 'true'
|
||||
RUN_WITH_BROWSING = os.environ.get('RUN_WITH_BROWSING', 'false').lower() == 'true'
|
||||
|
||||
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
|
||||
'CodeActAgent': codeact_user_response,
|
||||
@@ -88,6 +90,13 @@ def get_instruction(instance: pd.Series, metadata: EvalMetadata):
|
||||
'5. Think about edgecases and make sure your fix handles them as well\n'
|
||||
"Your thinking should be thorough and so it's fine if it's very long.\n"
|
||||
)
|
||||
|
||||
if RUN_WITH_BROWSING:
|
||||
instruction += (
|
||||
'<IMPORTANT!>\n'
|
||||
'You SHOULD NEVER attempt to browse the web. '
|
||||
'</IMPORTANT!>\n'
|
||||
)
|
||||
return instruction
|
||||
|
||||
|
||||
@@ -101,7 +110,7 @@ def get_instance_docker_image(instance_id: str) -> str:
|
||||
image_name = image_name.replace(
|
||||
'__', '_s_'
|
||||
) # to comply with docker image naming convention
|
||||
return DOCKER_IMAGE_PREFIX.rstrip('/') + '/' + image_name
|
||||
return (DOCKER_IMAGE_PREFIX.rstrip('/') + '/' + image_name).lower()
|
||||
|
||||
|
||||
def get_config(
|
||||
@@ -142,18 +151,14 @@ def get_config(
|
||||
workspace_base=None,
|
||||
workspace_mount_path=None,
|
||||
)
|
||||
if metadata.llm_config.log_completions:
|
||||
metadata.llm_config.log_completions_folder = os.path.join(
|
||||
metadata.eval_output_dir, 'llm_completions', instance['instance_id']
|
||||
config.set_llm_config(
|
||||
update_llm_config_for_completions_logging(
|
||||
metadata.llm_config, metadata.eval_output_dir, instance['instance_id']
|
||||
)
|
||||
logger.info(
|
||||
f'Logging LLM completions for instance {instance["instance_id"]} to '
|
||||
f'{metadata.llm_config.log_completions_folder}'
|
||||
)
|
||||
config.set_llm_config(metadata.llm_config)
|
||||
)
|
||||
agent_config = AgentConfig(
|
||||
codeact_enable_jupyter=False,
|
||||
codeact_enable_browsing_delegate=False,
|
||||
codeact_enable_browsing=RUN_WITH_BROWSING,
|
||||
codeact_enable_llm_editor=False,
|
||||
)
|
||||
config.set_agent_config(agent_config)
|
||||
|
||||
@@ -34,6 +34,11 @@ if [ -z "$USE_INSTANCE_IMAGE" ]; then
|
||||
USE_INSTANCE_IMAGE=true
|
||||
fi
|
||||
|
||||
if [ -z "$RUN_WITH_BROWSING" ]; then
|
||||
echo "RUN_WITH_BROWSING not specified, use default false"
|
||||
RUN_WITH_BROWSING=false
|
||||
fi
|
||||
|
||||
|
||||
if [ -z "$DATASET" ]; then
|
||||
echo "DATASET not specified, use default princeton-nlp/SWE-bench_Lite"
|
||||
@@ -47,6 +52,8 @@ fi
|
||||
|
||||
export USE_INSTANCE_IMAGE=$USE_INSTANCE_IMAGE
|
||||
echo "USE_INSTANCE_IMAGE: $USE_INSTANCE_IMAGE"
|
||||
export RUN_WITH_BROWSING=$RUN_WITH_BROWSING
|
||||
echo "RUN_WITH_BROWSING: $RUN_WITH_BROWSING"
|
||||
|
||||
get_agent_version
|
||||
|
||||
@@ -67,6 +74,10 @@ if [ "$USE_HINT_TEXT" = false ]; then
|
||||
EVAL_NOTE="$EVAL_NOTE-no-hint"
|
||||
fi
|
||||
|
||||
if [ "$RUN_WITH_BROWSING" = true ]; then
|
||||
EVAL_NOTE="$EVAL_NOTE-with-browsing"
|
||||
fi
|
||||
|
||||
if [ -n "$EXP_NAME" ]; then
|
||||
EVAL_NOTE="$EVAL_NOTE-$EXP_NAME"
|
||||
fi
|
||||
|
||||
@@ -411,3 +411,20 @@ def reset_logger_for_multiprocessing(
|
||||
)
|
||||
file_handler.setLevel(logging.INFO)
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
|
||||
def update_llm_config_for_completions_logging(
|
||||
llm_config: LLMConfig,
|
||||
eval_output_dir: str,
|
||||
instance_id: str,
|
||||
) -> LLMConfig:
|
||||
"""Update the LLM config for logging completions."""
|
||||
if llm_config.log_completions:
|
||||
llm_config.log_completions_folder = os.path.join(
|
||||
eval_output_dir, 'llm_completions', instance_id
|
||||
)
|
||||
logger.info(
|
||||
f'Logging LLM completions for instance {instance_id} to '
|
||||
f'{llm_config.log_completions_folder}'
|
||||
)
|
||||
return llm_config
|
||||
|
||||
@@ -5,7 +5,6 @@ import { FeedbackForm } from "#/components/feedback-form";
|
||||
|
||||
describe("FeedbackForm", () => {
|
||||
const user = userEvent.setup();
|
||||
const onSubmitMock = vi.fn();
|
||||
const onCloseMock = vi.fn();
|
||||
|
||||
afterEach(() => {
|
||||
@@ -13,7 +12,7 @@ describe("FeedbackForm", () => {
|
||||
});
|
||||
|
||||
it("should render correctly", () => {
|
||||
render(<FeedbackForm onSubmit={onSubmitMock} onClose={onCloseMock} />);
|
||||
render(<FeedbackForm polarity="positive" onClose={onCloseMock} />);
|
||||
|
||||
screen.getByLabelText("Email");
|
||||
screen.getByLabelText("Private");
|
||||
@@ -24,7 +23,7 @@ describe("FeedbackForm", () => {
|
||||
});
|
||||
|
||||
it("should switch between private and public permissions", async () => {
|
||||
render(<FeedbackForm onSubmit={onSubmitMock} onClose={onCloseMock} />);
|
||||
render(<FeedbackForm polarity="positive" onClose={onCloseMock} />);
|
||||
const privateRadio = screen.getByLabelText("Private");
|
||||
const publicRadio = screen.getByLabelText("Public");
|
||||
|
||||
@@ -40,69 +39,11 @@ describe("FeedbackForm", () => {
|
||||
expect(publicRadio).not.toBeChecked();
|
||||
});
|
||||
|
||||
it("should call onSubmit when the form is submitted", async () => {
|
||||
render(<FeedbackForm onSubmit={onSubmitMock} onClose={onCloseMock} />);
|
||||
const email = screen.getByLabelText("Email");
|
||||
|
||||
await user.type(email, "test@test.test");
|
||||
await user.click(screen.getByRole("button", { name: "Submit" }));
|
||||
|
||||
expect(onSubmitMock).toHaveBeenCalledWith("private", "test@test.test"); // private is the default value
|
||||
});
|
||||
|
||||
it("should not call onSubmit when the email is invalid", async () => {
|
||||
render(<FeedbackForm onSubmit={onSubmitMock} onClose={onCloseMock} />);
|
||||
const email = screen.getByLabelText("Email");
|
||||
const submitButton = screen.getByRole("button", { name: "Submit" });
|
||||
|
||||
await user.click(submitButton);
|
||||
|
||||
expect(onSubmitMock).not.toHaveBeenCalled();
|
||||
|
||||
await user.type(email, "test");
|
||||
await user.click(submitButton);
|
||||
|
||||
expect(onSubmitMock).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("should submit public permissions when the public radio is checked", async () => {
|
||||
render(<FeedbackForm onSubmit={onSubmitMock} onClose={onCloseMock} />);
|
||||
const email = screen.getByLabelText("Email");
|
||||
const publicRadio = screen.getByLabelText("Public");
|
||||
|
||||
await user.type(email, "test@test.test");
|
||||
await user.click(publicRadio);
|
||||
await user.click(screen.getByRole("button", { name: "Submit" }));
|
||||
|
||||
expect(onSubmitMock).toHaveBeenCalledWith("public", "test@test.test");
|
||||
});
|
||||
|
||||
it("should call onClose when the close button is clicked", async () => {
|
||||
render(<FeedbackForm onSubmit={onSubmitMock} onClose={onCloseMock} />);
|
||||
render(<FeedbackForm polarity="positive" onClose={onCloseMock} />);
|
||||
await user.click(screen.getByRole("button", { name: "Cancel" }));
|
||||
|
||||
expect(onSubmitMock).not.toHaveBeenCalled();
|
||||
expect(onCloseMock).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("should disable the buttons if isSubmitting is true", () => {
|
||||
const { rerender } = render(
|
||||
<FeedbackForm onSubmit={onSubmitMock} onClose={onCloseMock} />,
|
||||
);
|
||||
const submitButton = screen.getByRole("button", { name: "Submit" });
|
||||
const cancelButton = screen.getByRole("button", { name: "Cancel" });
|
||||
|
||||
expect(submitButton).not.toBeDisabled();
|
||||
expect(cancelButton).not.toBeDisabled();
|
||||
|
||||
rerender(
|
||||
<FeedbackForm
|
||||
onSubmit={onSubmitMock}
|
||||
onClose={onCloseMock}
|
||||
isSubmitting
|
||||
/>,
|
||||
);
|
||||
expect(submitButton).toBeDisabled();
|
||||
expect(cancelButton).toBeDisabled();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -16,13 +16,16 @@ vi.mock("../../services/fileService", async () => ({
|
||||
}));
|
||||
|
||||
const renderFileExplorerWithRunningAgentState = () =>
|
||||
renderWithProviders(<FileExplorer error={null} />, {
|
||||
preloadedState: {
|
||||
agent: {
|
||||
curAgentState: AgentState.RUNNING,
|
||||
renderWithProviders(
|
||||
<FileExplorer error={null} isOpen onToggle={() => {}} />,
|
||||
{
|
||||
preloadedState: {
|
||||
agent: {
|
||||
curAgentState: AgentState.RUNNING,
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
);
|
||||
|
||||
describe.skip("FileExplorer", () => {
|
||||
afterEach(() => {
|
||||
|
||||
Generated
+43
-2
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "openhands-frontend",
|
||||
"version": "0.11.0",
|
||||
"version": "0.12.0",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "openhands-frontend",
|
||||
"version": "0.11.0",
|
||||
"version": "0.12.0",
|
||||
"dependencies": {
|
||||
"@monaco-editor/react": "^4.6.0",
|
||||
"@nextui-org/react": "^2.4.8",
|
||||
@@ -26,6 +26,7 @@
|
||||
"isbot": "^5.1.17",
|
||||
"jose": "^5.9.4",
|
||||
"monaco-editor": "^0.52.0",
|
||||
"posthog-js": "^1.176.0",
|
||||
"react": "^18.3.1",
|
||||
"react-dom": "^18.3.1",
|
||||
"react-highlight": "^0.15.0",
|
||||
@@ -7864,6 +7865,16 @@
|
||||
"node": ">=6.6.0"
|
||||
}
|
||||
},
|
||||
"node_modules/core-js": {
|
||||
"version": "3.38.1",
|
||||
"resolved": "https://registry.npmjs.org/core-js/-/core-js-3.38.1.tgz",
|
||||
"integrity": "sha512-OP35aUorbU3Zvlx7pjsFdu1rGNnD4pgw/CWoYzRY3t2EzoVT7shKHY1dlAy3f41cGIO7ZDPQimhGFTlEYkG/Hw==",
|
||||
"hasInstallScript": true,
|
||||
"funding": {
|
||||
"type": "opencollective",
|
||||
"url": "https://opencollective.com/core-js"
|
||||
}
|
||||
},
|
||||
"node_modules/core-util-is": {
|
||||
"version": "1.0.3",
|
||||
"resolved": "https://registry.npmjs.org/core-util-is/-/core-util-is-1.0.3.tgz",
|
||||
@@ -9666,6 +9677,11 @@
|
||||
"url": "https://github.com/sponsors/wooorm"
|
||||
}
|
||||
},
|
||||
"node_modules/fflate": {
|
||||
"version": "0.4.8",
|
||||
"resolved": "https://registry.npmjs.org/fflate/-/fflate-0.4.8.tgz",
|
||||
"integrity": "sha512-FJqqoDBR00Mdj9ppamLa/Y7vxm+PRmNWA67N846RvsoYVMKB4q3y/de5PA7gUmRMYK/8CMz2GDZQmCRN1wBcWA=="
|
||||
},
|
||||
"node_modules/file-entry-cache": {
|
||||
"version": "6.0.1",
|
||||
"resolved": "https://registry.npmjs.org/file-entry-cache/-/file-entry-cache-6.0.1.tgz",
|
||||
@@ -19653,6 +19669,31 @@
|
||||
"resolved": "https://registry.npmjs.org/postcss-value-parser/-/postcss-value-parser-4.2.0.tgz",
|
||||
"integrity": "sha512-1NNCs6uurfkVbeXG4S8JFT9t19m45ICnif8zWLd5oPSZ50QnwMfK+H3jv408d4jw/7Bttv5axS5IiHoLaVNHeQ=="
|
||||
},
|
||||
"node_modules/posthog-js": {
|
||||
"version": "1.176.0",
|
||||
"resolved": "https://registry.npmjs.org/posthog-js/-/posthog-js-1.176.0.tgz",
|
||||
"integrity": "sha512-T5XKNtRzp7q6CGb7Vc7wAI76rWap9fiuDUPxPsyPBPDkreKya91x9RIsSapAVFafwD1AEin1QMczCmt9Le9BWw==",
|
||||
"dependencies": {
|
||||
"core-js": "^3.38.1",
|
||||
"fflate": "^0.4.8",
|
||||
"preact": "^10.19.3",
|
||||
"web-vitals": "^4.2.0"
|
||||
}
|
||||
},
|
||||
"node_modules/posthog-js/node_modules/web-vitals": {
|
||||
"version": "4.2.4",
|
||||
"resolved": "https://registry.npmjs.org/web-vitals/-/web-vitals-4.2.4.tgz",
|
||||
"integrity": "sha512-r4DIlprAGwJ7YM11VZp4R884m0Vmgr6EAKe3P+kO0PPj3Unqyvv59rczf6UiGcb9Z8QxZVcqKNwv/g0WNdWwsw=="
|
||||
},
|
||||
"node_modules/preact": {
|
||||
"version": "10.24.3",
|
||||
"resolved": "https://registry.npmjs.org/preact/-/preact-10.24.3.tgz",
|
||||
"integrity": "sha512-Z2dPnBnMUfyQfSQ+GBdsGa16hz35YmLmtTLhM169uW944hYL6xzTYkJjC07j+Wosz733pMWx0fgON3JNw1jJQA==",
|
||||
"funding": {
|
||||
"type": "opencollective",
|
||||
"url": "https://opencollective.com/preact"
|
||||
}
|
||||
},
|
||||
"node_modules/prelude-ls": {
|
||||
"version": "1.2.1",
|
||||
"resolved": "https://registry.npmjs.org/prelude-ls/-/prelude-ls-1.2.1.tgz",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "openhands-frontend",
|
||||
"version": "0.11.0",
|
||||
"version": "0.12.0",
|
||||
"private": true,
|
||||
"type": "module",
|
||||
"engines": {
|
||||
@@ -25,6 +25,7 @@
|
||||
"isbot": "^5.1.17",
|
||||
"jose": "^5.9.4",
|
||||
"monaco-editor": "^0.52.0",
|
||||
"posthog-js": "^1.176.0",
|
||||
"react": "^18.3.1",
|
||||
"react-dom": "^18.3.1",
|
||||
"react-highlight": "^0.15.0",
|
||||
|
||||
+28
-107
@@ -1,4 +1,4 @@
|
||||
import { getValidFallbackHost } from "#/utils/get-valid-fallback-host";
|
||||
import { request } from "#/services/api";
|
||||
import {
|
||||
SaveFileSuccessResponse,
|
||||
FileUploadSuccessResponse,
|
||||
@@ -9,36 +9,13 @@ import {
|
||||
GetConfigResponse,
|
||||
} from "./open-hands.types";
|
||||
|
||||
/**
|
||||
* Generate the base URL of the OpenHands API
|
||||
* @returns Base URL of the OpenHands API
|
||||
*/
|
||||
const generateBaseURL = () => {
|
||||
const fallback = getValidFallbackHost();
|
||||
const baseUrl = import.meta.env.VITE_BACKEND_BASE_URL || fallback;
|
||||
|
||||
if (typeof window === "undefined") {
|
||||
return `http://${baseUrl}`;
|
||||
}
|
||||
return `${window.location.protocol}//${baseUrl}`;
|
||||
};
|
||||
|
||||
/**
|
||||
* Class to interact with the OpenHands API
|
||||
*/
|
||||
class OpenHands {
|
||||
/**
|
||||
* Base URL of the OpenHands API
|
||||
*/
|
||||
static BASE_URL = generateBaseURL();
|
||||
|
||||
/**
|
||||
* Retrieve the list of models available
|
||||
* @returns List of models available
|
||||
*/
|
||||
static async getModels(): Promise<string[]> {
|
||||
const response = await fetch(`${OpenHands.BASE_URL}/api/options/models`);
|
||||
return response.json();
|
||||
return request("/api/options/models");
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -46,8 +23,7 @@ class OpenHands {
|
||||
* @returns List of agents available
|
||||
*/
|
||||
static async getAgents(): Promise<string[]> {
|
||||
const response = await fetch(`${OpenHands.BASE_URL}/api/options/agents`);
|
||||
return response.json();
|
||||
return request(`/api/options/agents`);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -55,178 +31,123 @@ class OpenHands {
|
||||
* @returns List of security analyzers available
|
||||
*/
|
||||
static async getSecurityAnalyzers(): Promise<string[]> {
|
||||
const response = await fetch(
|
||||
`${OpenHands.BASE_URL}/api/options/security-analyzers`,
|
||||
);
|
||||
return response.json();
|
||||
return request(`/api/options/security-analyzers`);
|
||||
}
|
||||
|
||||
static async getConfig(): Promise<GetConfigResponse> {
|
||||
const response = await fetch("config.json", {
|
||||
headers: {
|
||||
"Cache-Control": "no-cache",
|
||||
},
|
||||
});
|
||||
return response.json();
|
||||
return request("/config.json");
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieve the list of files available in the workspace
|
||||
* @param token User token provided by the server
|
||||
* @param path Path to list files from
|
||||
* @returns List of files available in the given path. If path is not provided, it lists all the files in the workspace
|
||||
*/
|
||||
static async getFiles(token: string, path?: string): Promise<string[]> {
|
||||
const url = new URL(`${OpenHands.BASE_URL}/api/list-files`);
|
||||
if (path) url.searchParams.append("path", path);
|
||||
|
||||
const response = await fetch(url.toString(), {
|
||||
headers: {
|
||||
Authorization: `Bearer ${token}`,
|
||||
},
|
||||
});
|
||||
|
||||
return response.json();
|
||||
static async getFiles(path?: string): Promise<string[]> {
|
||||
let url = "/api/list-files";
|
||||
if (path) url += `?path=${encodeURIComponent(path)}`;
|
||||
return request(url);
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieve the content of a file
|
||||
* @param token User token provided by the server
|
||||
* @param path Full path of the file to retrieve
|
||||
* @returns Content of the file
|
||||
*/
|
||||
static async getFile(token: string, path: string): Promise<string> {
|
||||
const url = new URL(`${OpenHands.BASE_URL}/api/select-file`);
|
||||
url.searchParams.append("file", path);
|
||||
const response = await fetch(url.toString(), {
|
||||
headers: {
|
||||
Authorization: `Bearer ${token}`,
|
||||
},
|
||||
});
|
||||
|
||||
const data = await response.json();
|
||||
static async getFile(path: string): Promise<string> {
|
||||
const url = `/api/select-file?file=${encodeURIComponent(path)}`;
|
||||
const data = await request(url);
|
||||
return data.code;
|
||||
}
|
||||
|
||||
/**
|
||||
* Save the content of a file
|
||||
* @param token User token provided by the server
|
||||
* @param path Full path of the file to save
|
||||
* @param content Content to save in the file
|
||||
* @returns Success message or error message
|
||||
*/
|
||||
static async saveFile(
|
||||
token: string,
|
||||
path: string,
|
||||
content: string,
|
||||
): Promise<SaveFileSuccessResponse | ErrorResponse> {
|
||||
const response = await fetch(`${OpenHands.BASE_URL}/api/save-file`, {
|
||||
return request(`/api/save-file`, {
|
||||
method: "POST",
|
||||
body: JSON.stringify({ filePath: path, content }),
|
||||
headers: {
|
||||
Authorization: `Bearer ${token}`,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
|
||||
return response.json();
|
||||
}
|
||||
|
||||
/**
|
||||
* Upload a file to the workspace
|
||||
* @param token User token provided by the server
|
||||
* @param file File to upload
|
||||
* @returns Success message or error message
|
||||
*/
|
||||
static async uploadFiles(
|
||||
token: string,
|
||||
file: File[],
|
||||
): Promise<FileUploadSuccessResponse | ErrorResponse> {
|
||||
const formData = new FormData();
|
||||
file.forEach((f) => formData.append("files", f));
|
||||
|
||||
const response = await fetch(`${OpenHands.BASE_URL}/api/upload-files`, {
|
||||
return request(`/api/upload-files`, {
|
||||
method: "POST",
|
||||
body: formData,
|
||||
headers: {
|
||||
Authorization: `Bearer ${token}`,
|
||||
},
|
||||
});
|
||||
|
||||
return response.json();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the blob of the workspace zip
|
||||
* @param token User token provided by the server
|
||||
* @returns Blob of the workspace zip
|
||||
*/
|
||||
static async getWorkspaceZip(token: string): Promise<Blob> {
|
||||
const response = await fetch(`${OpenHands.BASE_URL}/api/zip-directory`, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${token}`,
|
||||
},
|
||||
});
|
||||
|
||||
static async getWorkspaceZip(): Promise<Blob> {
|
||||
const response = await request(`/api/zip-directory`, {}, false, true);
|
||||
return response.blob();
|
||||
}
|
||||
|
||||
/**
|
||||
* Send feedback to the server
|
||||
* @param token User token provided by the server
|
||||
* @param data Feedback data
|
||||
* @returns The stored feedback data
|
||||
*/
|
||||
static async sendFeedback(
|
||||
token: string,
|
||||
data: Feedback,
|
||||
): Promise<FeedbackResponse> {
|
||||
const response = await fetch(`${OpenHands.BASE_URL}/api/submit-feedback`, {
|
||||
static async submitFeedback(data: Feedback): Promise<FeedbackResponse> {
|
||||
return request(`/api/submit-feedback`, {
|
||||
method: "POST",
|
||||
body: JSON.stringify(data),
|
||||
headers: {
|
||||
Authorization: `Bearer ${token}`,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
|
||||
return response.json();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the GitHub access token
|
||||
* @param code Code provided by GitHub
|
||||
* @returns GitHub access token
|
||||
*/
|
||||
static async getGitHubAccessToken(
|
||||
code: string,
|
||||
): Promise<GitHubAccessTokenResponse> {
|
||||
const response = await fetch(`${OpenHands.BASE_URL}/api/github/callback`, {
|
||||
return request(`/api/github/callback`, {
|
||||
method: "POST",
|
||||
body: JSON.stringify({ code }),
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
|
||||
return response.json();
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the user is authenticated
|
||||
* @param login The user's GitHub login handle
|
||||
* @returns Whether the user is authenticated
|
||||
* Authenticate with GitHub token
|
||||
* @returns Response with authentication status and user info if successful
|
||||
*/
|
||||
static async isAuthenticated(login: string): Promise<boolean> {
|
||||
const response = await fetch(`${OpenHands.BASE_URL}/api/authenticate`, {
|
||||
method: "POST",
|
||||
body: JSON.stringify({ login }),
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
static async authenticate(): Promise<Response> {
|
||||
return request(
|
||||
`/api/authenticate`,
|
||||
{
|
||||
method: "POST",
|
||||
},
|
||||
});
|
||||
|
||||
return response.status === 200;
|
||||
true,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -27,11 +27,16 @@ export interface GitHubAccessTokenResponse {
|
||||
access_token: string;
|
||||
}
|
||||
|
||||
export interface AuthenticationResponse {
|
||||
message: string;
|
||||
login?: string; // Only present when allow list is enabled
|
||||
}
|
||||
|
||||
export interface Feedback {
|
||||
version: string;
|
||||
email: string;
|
||||
token: string;
|
||||
feedback: "positive" | "negative";
|
||||
polarity: "positive" | "negative";
|
||||
permissions: "public" | "private";
|
||||
trajectory: unknown[];
|
||||
}
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
import { useFetcher } from "@remix-run/react";
|
||||
import { ModalBackdrop } from "./modals/modal-backdrop";
|
||||
import ModalBody from "./modals/ModalBody";
|
||||
import ModalButton from "./buttons/ModalButton";
|
||||
import {
|
||||
BaseModalTitle,
|
||||
BaseModalDescription,
|
||||
} from "./modals/confirmation-modals/BaseModal";
|
||||
|
||||
export function AnalyticsConsentFormModal() {
|
||||
const fetcher = useFetcher({ key: "set-consent" });
|
||||
|
||||
return (
|
||||
<ModalBackdrop>
|
||||
<fetcher.Form
|
||||
method="POST"
|
||||
action="/set-consent"
|
||||
className="flex flex-col gap-2"
|
||||
>
|
||||
<ModalBody>
|
||||
<BaseModalTitle title="Your Privacy Preferences" />
|
||||
<BaseModalDescription>
|
||||
We use tools to understand how our application is used to improve
|
||||
your experience. You can enable or disable analytics. Your
|
||||
preferences will be stored and can be updated anytime.
|
||||
</BaseModalDescription>
|
||||
|
||||
<label className="flex gap-2 items-center self-start">
|
||||
<input name="analytics" type="checkbox" defaultChecked />
|
||||
Send anonymous usage data
|
||||
</label>
|
||||
|
||||
<ModalButton
|
||||
type="submit"
|
||||
text="Confirm Preferences"
|
||||
className="bg-primary text-white w-full hover:opacity-80"
|
||||
/>
|
||||
</ModalBody>
|
||||
</fetcher.Form>
|
||||
</ModalBackdrop>
|
||||
);
|
||||
}
|
||||
@@ -1,6 +1,5 @@
|
||||
import { useDispatch, useSelector } from "react-redux";
|
||||
import React from "react";
|
||||
import { useFetcher } from "@remix-run/react";
|
||||
import { useSocket } from "#/context/socket";
|
||||
import { convertImageToBase64 } from "#/utils/convert-image-to-base-64";
|
||||
import { ChatMessage } from "./chat-message";
|
||||
@@ -13,10 +12,6 @@ import { RootState } from "#/store";
|
||||
import AgentState from "#/types/AgentState";
|
||||
import { generateAgentStateChangeEvent } from "#/services/agentStateService";
|
||||
import { FeedbackModal } from "./feedback-modal";
|
||||
import { Feedback } from "#/api/open-hands.types";
|
||||
import { getToken } from "#/services/auth";
|
||||
import { removeApiKey, removeUnwantedKeys } from "#/utils/utils";
|
||||
import { clientAction } from "#/routes/submit-feedback";
|
||||
import { useScrollToBottom } from "#/hooks/useScrollToBottom";
|
||||
import TypingIndicator from "./chat/TypingIndicator";
|
||||
import ConfirmationButtons from "./chat/ConfirmationButtons";
|
||||
@@ -24,16 +19,13 @@ import { ErrorMessage } from "./error-message";
|
||||
import { ContinueButton } from "./continue-button";
|
||||
import { ScrollToBottomButton } from "./scroll-to-bottom-button";
|
||||
|
||||
const FEEDBACK_VERSION = "1.0";
|
||||
|
||||
const isErrorMessage = (
|
||||
message: Message | ErrorMessage,
|
||||
): message is ErrorMessage => "error" in message;
|
||||
|
||||
export function ChatInterface() {
|
||||
const { send, events } = useSocket();
|
||||
const { send } = useSocket();
|
||||
const dispatch = useDispatch();
|
||||
const fetcher = useFetcher<typeof clientAction>({ key: "feedback" });
|
||||
const scrollRef = React.useRef<HTMLDivElement>(null);
|
||||
const { scrollDomToBottom, onChatBodyScroll, hitBottom } =
|
||||
useScrollToBottom(scrollRef);
|
||||
@@ -44,7 +36,6 @@ export function ChatInterface() {
|
||||
const [feedbackPolarity, setFeedbackPolarity] = React.useState<
|
||||
"positive" | "negative"
|
||||
>("positive");
|
||||
const [feedbackShared, setFeedbackShared] = React.useState(0);
|
||||
const [feedbackModalIsOpen, setFeedbackModalIsOpen] = React.useState(false);
|
||||
|
||||
const handleSendMessage = async (content: string, files: File[]) => {
|
||||
@@ -71,30 +62,6 @@ export function ChatInterface() {
|
||||
setFeedbackPolarity(polarity);
|
||||
};
|
||||
|
||||
const handleSubmitFeedback = (
|
||||
permissions: "private" | "public",
|
||||
email: string,
|
||||
) => {
|
||||
const feedback: Feedback = {
|
||||
version: FEEDBACK_VERSION,
|
||||
feedback: feedbackPolarity,
|
||||
email,
|
||||
permissions,
|
||||
token: getToken(),
|
||||
trajectory: removeApiKey(removeUnwantedKeys(events)),
|
||||
};
|
||||
|
||||
const formData = new FormData();
|
||||
formData.append("feedback", JSON.stringify(feedback));
|
||||
|
||||
fetcher.submit(formData, {
|
||||
action: "/submit-feedback",
|
||||
method: "POST",
|
||||
});
|
||||
|
||||
setFeedbackShared(messages.length);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="h-full flex flex-col justify-between">
|
||||
<div
|
||||
@@ -130,16 +97,14 @@ export function ChatInterface() {
|
||||
|
||||
<div className="flex flex-col gap-[6px] px-4 pb-4">
|
||||
<div className="flex justify-between relative">
|
||||
{feedbackShared !== messages.length && messages.length > 3 && (
|
||||
<FeedbackActions
|
||||
onPositiveFeedback={() =>
|
||||
onClickShareFeedbackActionButton("positive")
|
||||
}
|
||||
onNegativeFeedback={() =>
|
||||
onClickShareFeedbackActionButton("negative")
|
||||
}
|
||||
/>
|
||||
)}
|
||||
<FeedbackActions
|
||||
onPositiveFeedback={() =>
|
||||
onClickShareFeedbackActionButton("positive")
|
||||
}
|
||||
onNegativeFeedback={() =>
|
||||
onClickShareFeedbackActionButton("negative")
|
||||
}
|
||||
/>
|
||||
<div className="absolute left-1/2 transform -translate-x-1/2 bottom-0">
|
||||
{messages.length > 2 &&
|
||||
curAgentState === AgentState.AWAITING_USER_INPUT && (
|
||||
@@ -163,9 +128,8 @@ export function ChatInterface() {
|
||||
|
||||
<FeedbackModal
|
||||
isOpen={feedbackModalIsOpen}
|
||||
isSubmitting={fetcher.state === "submitting"}
|
||||
onClose={() => setFeedbackModalIsOpen(false)}
|
||||
onSubmit={handleSubmitFeedback}
|
||||
polarity={feedbackPolarity}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -1,27 +1,81 @@
|
||||
import React from "react";
|
||||
import hotToast from "react-hot-toast";
|
||||
import ModalButton from "./buttons/ModalButton";
|
||||
import { Feedback } from "#/api/open-hands.types";
|
||||
import OpenHands from "#/api/open-hands";
|
||||
|
||||
const FEEDBACK_VERSION = "1.0";
|
||||
const VIEWER_PAGE = "https://www.all-hands.dev/share";
|
||||
|
||||
interface FeedbackFormProps {
|
||||
onSubmit: (permissions: "private" | "public", email: string) => void;
|
||||
onClose: () => void;
|
||||
isSubmitting?: boolean;
|
||||
polarity: "positive" | "negative";
|
||||
}
|
||||
|
||||
export function FeedbackForm({
|
||||
onSubmit,
|
||||
onClose,
|
||||
isSubmitting,
|
||||
}: FeedbackFormProps) {
|
||||
const handleSubmit = (event: React.FormEvent<HTMLFormElement>) => {
|
||||
export function FeedbackForm({ onClose, polarity }: FeedbackFormProps) {
|
||||
const [isSubmitting, setIsSubmitting] = React.useState(false);
|
||||
|
||||
const copiedToClipboardToast = () => {
|
||||
hotToast("Password copied to clipboard", {
|
||||
icon: "📋",
|
||||
position: "bottom-right",
|
||||
});
|
||||
};
|
||||
|
||||
const onPressToast = (password: string) => {
|
||||
navigator.clipboard.writeText(password);
|
||||
copiedToClipboardToast();
|
||||
};
|
||||
|
||||
const shareFeedbackToast = (
|
||||
message: string,
|
||||
link: string,
|
||||
password: string,
|
||||
) => {
|
||||
hotToast(
|
||||
<div className="flex flex-col gap-1">
|
||||
<span>{message}</span>
|
||||
<a
|
||||
data-testid="toast-share-url"
|
||||
className="text-blue-500 underline"
|
||||
onClick={() => onPressToast(password)}
|
||||
href={link}
|
||||
target="_blank"
|
||||
rel="noreferrer"
|
||||
>
|
||||
Go to shared feedback
|
||||
</a>
|
||||
<span onClick={() => onPressToast(password)} className="cursor-pointer">
|
||||
Password: {password} <span className="text-gray-500">(copy)</span>
|
||||
</span>
|
||||
</div>,
|
||||
{ duration: 10000 },
|
||||
);
|
||||
};
|
||||
|
||||
const handleSubmit = async (event: React.FormEvent<HTMLFormElement>) => {
|
||||
event?.preventDefault();
|
||||
const formData = new FormData(event.currentTarget);
|
||||
setIsSubmitting(true);
|
||||
|
||||
const email = formData.get("email")?.toString();
|
||||
const permissions = formData.get("permissions")?.toString() as
|
||||
| "private"
|
||||
| "public"
|
||||
| undefined;
|
||||
const email = formData.get("email")?.toString() || "";
|
||||
const permissions = (formData.get("permissions")?.toString() ||
|
||||
"private") as "private" | "public";
|
||||
|
||||
if (email) onSubmit(permissions || "private", email);
|
||||
const feedback: Feedback = {
|
||||
version: FEEDBACK_VERSION,
|
||||
email,
|
||||
polarity,
|
||||
permissions,
|
||||
trajectory: [],
|
||||
token: "",
|
||||
};
|
||||
|
||||
const response = await OpenHands.submitFeedback(feedback);
|
||||
const { message, feedback_id, password } = response.body; // eslint-disable-line
|
||||
const link = `${VIEWER_PAGE}?share_id=${feedback_id}`;
|
||||
shareFeedbackToast(message, link, password);
|
||||
setIsSubmitting(false);
|
||||
};
|
||||
|
||||
return (
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
import React from "react";
|
||||
import hotToast, { toast } from "react-hot-toast";
|
||||
import { useFetcher } from "@remix-run/react";
|
||||
import { FeedbackForm } from "./feedback-form";
|
||||
import {
|
||||
BaseModalTitle,
|
||||
@@ -8,82 +6,18 @@ import {
|
||||
} from "./modals/confirmation-modals/BaseModal";
|
||||
import { ModalBackdrop } from "./modals/modal-backdrop";
|
||||
import ModalBody from "./modals/ModalBody";
|
||||
import { clientAction } from "#/routes/submit-feedback";
|
||||
|
||||
interface FeedbackModalProps {
|
||||
onSubmit: (permissions: "private" | "public", email: string) => void;
|
||||
onClose: () => void;
|
||||
isOpen: boolean;
|
||||
isSubmitting?: boolean;
|
||||
polarity: "positive" | "negative";
|
||||
}
|
||||
|
||||
export function FeedbackModal({
|
||||
onSubmit,
|
||||
onClose,
|
||||
isOpen,
|
||||
isSubmitting,
|
||||
polarity,
|
||||
}: FeedbackModalProps) {
|
||||
const fetcher = useFetcher<typeof clientAction>({ key: "feedback" });
|
||||
const isInitialRender = React.useRef(true);
|
||||
|
||||
const copiedToClipboardToast = () => {
|
||||
hotToast("Password copied to clipboard", {
|
||||
icon: "📋",
|
||||
position: "bottom-right",
|
||||
});
|
||||
};
|
||||
|
||||
const onPressToast = (password: string) => {
|
||||
navigator.clipboard.writeText(password);
|
||||
copiedToClipboardToast();
|
||||
};
|
||||
|
||||
const shareFeedbackToast = (
|
||||
message: string,
|
||||
link: string,
|
||||
password: string,
|
||||
) => {
|
||||
hotToast(
|
||||
<div className="flex flex-col gap-1">
|
||||
<span>{message}</span>
|
||||
<a
|
||||
data-testid="toast-share-url"
|
||||
className="text-blue-500 underline"
|
||||
onClick={() => onPressToast(password)}
|
||||
href={link}
|
||||
target="_blank"
|
||||
rel="noreferrer"
|
||||
>
|
||||
Go to shared feedback
|
||||
</a>
|
||||
<span onClick={() => onPressToast(password)} className="cursor-pointer">
|
||||
Password: {password} <span className="text-gray-500">(copy)</span>
|
||||
</span>
|
||||
</div>,
|
||||
{ duration: 10000 },
|
||||
);
|
||||
};
|
||||
|
||||
React.useEffect(() => {
|
||||
if (isInitialRender.current) {
|
||||
isInitialRender.current = false;
|
||||
return;
|
||||
}
|
||||
|
||||
// Handle feedback submission
|
||||
if (fetcher.state === "idle" && fetcher.data) {
|
||||
if (!fetcher.data.success) {
|
||||
toast.error("Error submitting feedback");
|
||||
} else if (fetcher.data.data) {
|
||||
const { data } = fetcher.data;
|
||||
const { message, link, password } = data;
|
||||
shareFeedbackToast(message, link, password);
|
||||
}
|
||||
|
||||
onClose();
|
||||
}
|
||||
}, [fetcher.state, fetcher.data?.success]);
|
||||
|
||||
if (!isOpen) return null;
|
||||
|
||||
return (
|
||||
@@ -91,11 +25,7 @@ export function FeedbackModal({
|
||||
<ModalBody>
|
||||
<BaseModalTitle title="Feedback" />
|
||||
<BaseModalDescription description="To help us improve, we collect feedback from your interactions to improve our prompts. By submitting this form, you consent to us collecting this data." />
|
||||
<FeedbackForm
|
||||
onSubmit={onSubmit}
|
||||
onClose={onClose}
|
||||
isSubmitting={isSubmitting}
|
||||
/>
|
||||
<FeedbackForm onClose={onClose} polarity={polarity} />
|
||||
</ModalBody>
|
||||
</ModalBackdrop>
|
||||
);
|
||||
|
||||
@@ -91,14 +91,15 @@ function ExplorerActions({
|
||||
}
|
||||
|
||||
interface FileExplorerProps {
|
||||
isOpen: boolean;
|
||||
onToggle: () => void;
|
||||
error: string | null;
|
||||
}
|
||||
|
||||
function FileExplorer({ error }: FileExplorerProps) {
|
||||
function FileExplorer({ error, isOpen, onToggle }: FileExplorerProps) {
|
||||
const { revalidate } = useRevalidator();
|
||||
|
||||
const { paths, setPaths } = useFiles();
|
||||
const [isHidden, setIsHidden] = React.useState(false);
|
||||
const [isDragging, setIsDragging] = React.useState(false);
|
||||
|
||||
const { curAgentState } = useSelector((state: RootState) => state.agent);
|
||||
@@ -117,52 +118,47 @@ function FileExplorer({ error }: FileExplorerProps) {
|
||||
return;
|
||||
}
|
||||
dispatch(setRefreshID(Math.random()));
|
||||
// TODO: Get token from data loader
|
||||
const token = localStorage.getItem("token");
|
||||
if (token) OpenHands.getFiles(token).then(setPaths);
|
||||
OpenHands.getFiles().then(setPaths);
|
||||
revalidate();
|
||||
};
|
||||
|
||||
const uploadFileData = async (files: FileList) => {
|
||||
try {
|
||||
const token = localStorage.getItem("token");
|
||||
if (token) {
|
||||
const result = await OpenHands.uploadFiles(token, Array.from(files));
|
||||
const result = await OpenHands.uploadFiles(Array.from(files));
|
||||
|
||||
if (isOpenHandsErrorResponse(result)) {
|
||||
// Handle error response
|
||||
toast.error(
|
||||
`upload-error-${new Date().getTime()}`,
|
||||
result.error || t(I18nKey.EXPLORER$UPLOAD_ERROR_MESSAGE),
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const uploadedCount = result.uploaded_files.length;
|
||||
const skippedCount = result.skipped_files.length;
|
||||
|
||||
if (uploadedCount > 0) {
|
||||
toast.success(
|
||||
`upload-success-${new Date().getTime()}`,
|
||||
t(I18nKey.EXPLORER$UPLOAD_SUCCESS_MESSAGE, {
|
||||
count: uploadedCount,
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
if (skippedCount > 0) {
|
||||
const message = t(I18nKey.EXPLORER$UPLOAD_PARTIAL_SUCCESS_MESSAGE, {
|
||||
count: skippedCount,
|
||||
});
|
||||
toast.info(message);
|
||||
}
|
||||
|
||||
if (uploadedCount === 0 && skippedCount === 0) {
|
||||
toast.info(t(I18nKey.EXPLORER$NO_FILES_UPLOADED_MESSAGE));
|
||||
}
|
||||
|
||||
refreshWorkspace();
|
||||
if (isOpenHandsErrorResponse(result)) {
|
||||
// Handle error response
|
||||
toast.error(
|
||||
`upload-error-${new Date().getTime()}`,
|
||||
result.error || t(I18nKey.EXPLORER$UPLOAD_ERROR_MESSAGE),
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const uploadedCount = result.uploaded_files.length;
|
||||
const skippedCount = result.skipped_files.length;
|
||||
|
||||
if (uploadedCount > 0) {
|
||||
toast.success(
|
||||
`upload-success-${new Date().getTime()}`,
|
||||
t(I18nKey.EXPLORER$UPLOAD_SUCCESS_MESSAGE, {
|
||||
count: uploadedCount,
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
if (skippedCount > 0) {
|
||||
const message = t(I18nKey.EXPLORER$UPLOAD_PARTIAL_SUCCESS_MESSAGE, {
|
||||
count: skippedCount,
|
||||
});
|
||||
toast.info(message);
|
||||
}
|
||||
|
||||
if (uploadedCount === 0 && skippedCount === 0) {
|
||||
toast.info(t(I18nKey.EXPLORER$NO_FILES_UPLOADED_MESSAGE));
|
||||
}
|
||||
|
||||
refreshWorkspace();
|
||||
} catch (e) {
|
||||
// Handle unexpected errors (network issues, etc.)
|
||||
toast.error(
|
||||
@@ -211,7 +207,7 @@ function FileExplorer({ error }: FileExplorerProps) {
|
||||
<div
|
||||
className={twMerge(
|
||||
"bg-neutral-800 h-full border-r-1 border-r-neutral-600 flex flex-col",
|
||||
isHidden ? "w-12" : "w-60",
|
||||
!isOpen ? "w-12" : "w-60",
|
||||
)}
|
||||
>
|
||||
<div className="flex flex-col relative h-full px-3 py-2">
|
||||
@@ -219,17 +215,17 @@ function FileExplorer({ error }: FileExplorerProps) {
|
||||
<div
|
||||
className={twMerge(
|
||||
"flex items-center",
|
||||
isHidden ? "justify-center" : "justify-between",
|
||||
!isOpen ? "justify-center" : "justify-between",
|
||||
)}
|
||||
>
|
||||
{!isHidden && (
|
||||
{isOpen && (
|
||||
<div className="text-neutral-300 font-bold text-sm">
|
||||
{t(I18nKey.EXPLORER$LABEL_WORKSPACE)}
|
||||
</div>
|
||||
)}
|
||||
<ExplorerActions
|
||||
isHidden={isHidden}
|
||||
toggleHidden={() => setIsHidden((prev) => !prev)}
|
||||
isHidden={!isOpen}
|
||||
toggleHidden={onToggle}
|
||||
onRefresh={refreshWorkspace}
|
||||
onUpload={selectFileInput}
|
||||
/>
|
||||
@@ -237,7 +233,7 @@ function FileExplorer({ error }: FileExplorerProps) {
|
||||
</div>
|
||||
{!error && (
|
||||
<div className="overflow-auto flex-grow">
|
||||
<div style={{ display: isHidden ? "none" : "block" }}>
|
||||
<div style={{ display: !isOpen ? "none" : "block" }}>
|
||||
<ExplorerTree files={paths} />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -59,14 +59,11 @@ function TreeNode({ path, defaultOpen = false }: TreeNodeProps) {
|
||||
return;
|
||||
}
|
||||
|
||||
const token = localStorage.getItem("token");
|
||||
if (token) {
|
||||
try {
|
||||
const newChildren = await OpenHands.getFiles(token, path);
|
||||
setChildren(newChildren);
|
||||
} catch (error) {
|
||||
toast.error("Failed to fetch files");
|
||||
}
|
||||
try {
|
||||
const newChildren = await OpenHands.getFiles(path);
|
||||
setChildren(newChildren);
|
||||
} catch (error) {
|
||||
toast.error("Failed to fetch files");
|
||||
}
|
||||
};
|
||||
|
||||
@@ -77,15 +74,13 @@ function TreeNode({ path, defaultOpen = false }: TreeNodeProps) {
|
||||
}, [refreshID, isOpen]);
|
||||
|
||||
const handleClick = async () => {
|
||||
const token = localStorage.getItem("token");
|
||||
|
||||
if (isDirectory) {
|
||||
setIsOpen((prev) => !prev);
|
||||
} else if (token) {
|
||||
} else {
|
||||
const code = modifiedFiles[path] || files[path];
|
||||
|
||||
try {
|
||||
const fetchedCode = await OpenHands.getFile(token, path);
|
||||
const fetchedCode = await OpenHands.getFile(path);
|
||||
setSelectedPath(path);
|
||||
if (!code || fetchedCode !== files[path]) {
|
||||
setFileContent(path, fetchedCode);
|
||||
|
||||
@@ -0,0 +1,94 @@
|
||||
import React from "react";
|
||||
import {
|
||||
isGitHubErrorReponse,
|
||||
retrieveAllGitHubUserRepositories,
|
||||
} from "#/api/github";
|
||||
import { SuggestionBox } from "#/routes/_oh._index/suggestion-box";
|
||||
import { ConnectToGitHubModal } from "./modals/connect-to-github-modal";
|
||||
import { ModalBackdrop } from "./modals/modal-backdrop";
|
||||
import { GitHubRepositorySelector } from "#/routes/_oh._index/github-repo-selector";
|
||||
import ModalButton from "./buttons/ModalButton";
|
||||
import GitHubLogo from "#/assets/branding/github-logo.svg?react";
|
||||
|
||||
interface GitHubAuthProps {
|
||||
onConnectToGitHub: () => void;
|
||||
repositories: GitHubRepository[];
|
||||
isLoggedIn: boolean;
|
||||
}
|
||||
|
||||
function GitHubAuth({
|
||||
onConnectToGitHub,
|
||||
repositories,
|
||||
isLoggedIn,
|
||||
}: GitHubAuthProps) {
|
||||
if (isLoggedIn) {
|
||||
return <GitHubRepositorySelector repositories={repositories} />;
|
||||
}
|
||||
|
||||
return (
|
||||
<ModalButton
|
||||
text="Connect to GitHub"
|
||||
icon={<GitHubLogo width={20} height={20} />}
|
||||
className="bg-[#791B80] w-full"
|
||||
onClick={onConnectToGitHub}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
interface GitHubRepositoriesSuggestionBoxProps {
|
||||
repositories: Awaited<
|
||||
ReturnType<typeof retrieveAllGitHubUserRepositories>
|
||||
> | null;
|
||||
gitHubAuthUrl: string | null;
|
||||
user: GitHubErrorReponse | GitHubUser | null;
|
||||
}
|
||||
|
||||
export function GitHubRepositoriesSuggestionBox({
|
||||
repositories,
|
||||
gitHubAuthUrl,
|
||||
user,
|
||||
}: GitHubRepositoriesSuggestionBoxProps) {
|
||||
const [connectToGitHubModalOpen, setConnectToGitHubModalOpen] =
|
||||
React.useState(false);
|
||||
|
||||
const handleConnectToGitHub = () => {
|
||||
if (gitHubAuthUrl) {
|
||||
window.location.href = gitHubAuthUrl;
|
||||
} else {
|
||||
setConnectToGitHubModalOpen(true);
|
||||
}
|
||||
};
|
||||
|
||||
if (isGitHubErrorReponse(repositories)) {
|
||||
return (
|
||||
<SuggestionBox
|
||||
title="Error Fetching Repositories"
|
||||
content={
|
||||
<p className="text-danger text-center">{repositories.message}</p>
|
||||
}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<SuggestionBox
|
||||
title="Open a Repo"
|
||||
content={
|
||||
<GitHubAuth
|
||||
isLoggedIn={!!user && !isGitHubErrorReponse(user)}
|
||||
repositories={repositories || []}
|
||||
onConnectToGitHub={handleConnectToGitHub}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
{connectToGitHubModalOpen && (
|
||||
<ModalBackdrop onClose={() => setConnectToGitHubModalOpen(false)}>
|
||||
<ConnectToGitHubModal
|
||||
onClose={() => setConnectToGitHubModalOpen(false)}
|
||||
/>
|
||||
</ModalBackdrop>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -14,12 +14,14 @@ interface AccountSettingsModalProps {
|
||||
onClose: () => void;
|
||||
selectedLanguage: string;
|
||||
gitHubError: boolean;
|
||||
analyticsConsent: string | null;
|
||||
}
|
||||
|
||||
function AccountSettingsModal({
|
||||
onClose,
|
||||
selectedLanguage,
|
||||
gitHubError,
|
||||
analyticsConsent,
|
||||
}: AccountSettingsModalProps) {
|
||||
const data = useRouteLoaderData<typeof clientLoader>("routes/_oh");
|
||||
const settingsFetcher = useFetcher<typeof settingsClientAction>({
|
||||
@@ -32,6 +34,7 @@ function AccountSettingsModal({
|
||||
const formData = new FormData(event.currentTarget);
|
||||
const language = formData.get("language")?.toString();
|
||||
const ghToken = formData.get("ghToken")?.toString();
|
||||
const analytics = formData.get("analytics")?.toString() === "on";
|
||||
|
||||
const accountForm = new FormData();
|
||||
const loginForm = new FormData();
|
||||
@@ -44,6 +47,7 @@ function AccountSettingsModal({
|
||||
accountForm.append("language", languageKey ?? "en");
|
||||
}
|
||||
if (ghToken) loginForm.append("ghToken", ghToken);
|
||||
accountForm.append("analytics", analytics.toString());
|
||||
|
||||
settingsFetcher.submit(accountForm, {
|
||||
method: "POST",
|
||||
@@ -101,6 +105,15 @@ function AccountSettingsModal({
|
||||
)}
|
||||
</div>
|
||||
|
||||
<label className="flex gap-2 items-center self-start">
|
||||
<input
|
||||
name="analytics"
|
||||
type="checkbox"
|
||||
defaultChecked={analyticsConsent === "true"}
|
||||
/>
|
||||
Enable analytics
|
||||
</label>
|
||||
|
||||
<div className="flex flex-col gap-2 w-full">
|
||||
<ModalButton
|
||||
disabled={
|
||||
|
||||
@@ -21,13 +21,17 @@ export function BaseModalTitle({ title }: BaseModalTitleProps) {
|
||||
}
|
||||
|
||||
interface BaseModalDescriptionProps {
|
||||
description: React.ReactNode;
|
||||
description?: React.ReactNode;
|
||||
children?: React.ReactNode;
|
||||
}
|
||||
|
||||
export function BaseModalDescription({
|
||||
description,
|
||||
children,
|
||||
}: BaseModalDescriptionProps) {
|
||||
return <span className="text-xs text-[#A3A3A3]">{description}</span>;
|
||||
return (
|
||||
<span className="text-xs text-[#A3A3A3]">{children || description}</span>
|
||||
);
|
||||
}
|
||||
|
||||
interface BaseModalProps {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import React from "react";
|
||||
import { Data } from "ws";
|
||||
import EventLogger from "#/utils/event-logger";
|
||||
import { getValidFallbackHost } from "#/utils/get-valid-fallback-host";
|
||||
|
||||
interface WebSocketClientOptions {
|
||||
token: string | null;
|
||||
@@ -46,12 +45,17 @@ function SocketProvider({ children }: SocketProviderProps) {
|
||||
);
|
||||
}
|
||||
|
||||
const fallback = getValidFallbackHost();
|
||||
const baseUrl = import.meta.env.VITE_BACKEND_BASE_URL || fallback;
|
||||
const baseUrl =
|
||||
import.meta.env.VITE_BACKEND_BASE_URL || window?.location.host;
|
||||
const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
|
||||
const ws = new WebSocket(
|
||||
`${protocol}//${baseUrl}/ws${options?.token ? `?token=${options.token}` : ""}`,
|
||||
);
|
||||
const sessionToken = options?.token || "NO_JWT"; // not allowed to be empty or duplicated
|
||||
const ghToken = localStorage.getItem("ghToken") || "NO_GITHUB";
|
||||
|
||||
const ws = new WebSocket(`${protocol}//${baseUrl}/ws`, [
|
||||
"openhands",
|
||||
sessionToken,
|
||||
ghToken,
|
||||
]);
|
||||
|
||||
ws.addEventListener("open", (event) => {
|
||||
setIsConnected(true);
|
||||
|
||||
@@ -6,13 +6,25 @@
|
||||
*/
|
||||
|
||||
import { RemixBrowser } from "@remix-run/react";
|
||||
import { startTransition, StrictMode } from "react";
|
||||
import React, { startTransition, StrictMode } from "react";
|
||||
import { hydrateRoot } from "react-dom/client";
|
||||
import { Provider } from "react-redux";
|
||||
import posthog from "posthog-js";
|
||||
import { SocketProvider } from "./context/socket";
|
||||
import "./i18n";
|
||||
import store from "./store";
|
||||
|
||||
function PosthogInit() {
|
||||
React.useEffect(() => {
|
||||
posthog.init("phc_3ESMmY9SgqEAGBB6sMGK5ayYHkeUuknH2vP6FmWH9RA", {
|
||||
api_host: "https://us.i.posthog.com",
|
||||
person_profiles: "identified_only",
|
||||
});
|
||||
}, []);
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
async function prepareApp() {
|
||||
if (
|
||||
process.env.NODE_ENV === "development" &&
|
||||
@@ -34,6 +46,7 @@ prepareApp().then(() =>
|
||||
<SocketProvider>
|
||||
<Provider store={store}>
|
||||
<RemixBrowser />
|
||||
<PosthogInit />
|
||||
</Provider>
|
||||
</SocketProvider>
|
||||
</StrictMode>,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { delay, http, HttpResponse } from "msw";
|
||||
|
||||
const openHandsHandlers = [
|
||||
http.get("http://localhost:3000/api/options/models", async () => {
|
||||
http.get("http://localhost:3001/api/options/models", async () => {
|
||||
await delay();
|
||||
return HttpResponse.json([
|
||||
"gpt-3.5-turbo",
|
||||
@@ -10,17 +10,17 @@ const openHandsHandlers = [
|
||||
]);
|
||||
}),
|
||||
|
||||
http.get("http://localhost:3000/api/options/agents", async () => {
|
||||
http.get("http://localhost:3001/api/options/agents", async () => {
|
||||
await delay();
|
||||
return HttpResponse.json(["CodeActAgent", "CoActAgent"]);
|
||||
}),
|
||||
|
||||
http.get("http://localhost:3000/api/options/security-analyzers", async () => {
|
||||
http.get("http://localhost:3001/api/options/security-analyzers", async () => {
|
||||
await delay();
|
||||
return HttpResponse.json(["mock-invariant"]);
|
||||
}),
|
||||
|
||||
http.get("http://localhost:3000/api/list-files", async ({ request }) => {
|
||||
http.get("http://localhost:3001/api/list-files", async ({ request }) => {
|
||||
await delay();
|
||||
|
||||
const token = request.headers
|
||||
@@ -32,11 +32,11 @@ const openHandsHandlers = [
|
||||
return HttpResponse.json(["file1.ts", "dir1/file2.ts", "file3.ts"]);
|
||||
}),
|
||||
|
||||
http.post("http://localhost:3000/api/save-file", () =>
|
||||
http.post("http://localhost:3001/api/save-file", () =>
|
||||
HttpResponse.json(null, { status: 200 }),
|
||||
),
|
||||
|
||||
http.get("http://localhost:3000/api/select-file", async ({ request }) => {
|
||||
http.get("http://localhost:3001/api/select-file", async ({ request }) => {
|
||||
await delay();
|
||||
|
||||
const token = request.headers
|
||||
@@ -58,7 +58,7 @@ const openHandsHandlers = [
|
||||
return HttpResponse.json(null, { status: 404 });
|
||||
}),
|
||||
|
||||
http.post("http://localhost:3000/api/submit-feedback", async () => {
|
||||
http.post("http://localhost:3001/api/submit-feedback", async () => {
|
||||
await delay(1200);
|
||||
|
||||
return HttpResponse.json({
|
||||
@@ -70,7 +70,9 @@ const openHandsHandlers = [
|
||||
|
||||
export const handlers = [
|
||||
...openHandsHandlers,
|
||||
http.get("https://api.github.com/user/repos", ({ request }) => {
|
||||
http.get("https://api.github.com/user/repos", async ({ request }) => {
|
||||
await delay(3500);
|
||||
|
||||
const token = request.headers
|
||||
.get("Authorization")
|
||||
?.replace("Bearer", "")
|
||||
@@ -85,7 +87,10 @@ export const handlers = [
|
||||
{ id: 2, full_name: "octocat/earth" },
|
||||
]);
|
||||
}),
|
||||
http.post("http://localhost:3000/api/submit-feedback", async () =>
|
||||
http.post("http://localhost:3001/api/submit-feedback", async () =>
|
||||
HttpResponse.json({ statusCode: 200 }, { status: 200 }),
|
||||
),
|
||||
http.post("https://us.i.posthog.com/e", async () =>
|
||||
HttpResponse.json(null, { status: 200 }),
|
||||
),
|
||||
];
|
||||
|
||||
@@ -1,54 +1,23 @@
|
||||
import {
|
||||
Await,
|
||||
ClientActionFunctionArgs,
|
||||
ClientLoaderFunctionArgs,
|
||||
json,
|
||||
defer,
|
||||
redirect,
|
||||
useLoaderData,
|
||||
useRouteLoaderData,
|
||||
} from "@remix-run/react";
|
||||
import React from "react";
|
||||
import React, { Suspense } from "react";
|
||||
import { SuggestionBox } from "./suggestion-box";
|
||||
import { TaskForm } from "./task-form";
|
||||
import { HeroHeading } from "./hero-heading";
|
||||
import { GitHubRepositorySelector } from "./github-repo-selector";
|
||||
import {
|
||||
isGitHubErrorReponse,
|
||||
retrieveAllGitHubUserRepositories,
|
||||
} from "#/api/github";
|
||||
import ModalButton from "#/components/buttons/ModalButton";
|
||||
import GitHubLogo from "#/assets/branding/github-logo.svg?react";
|
||||
import { ConnectToGitHubModal } from "#/components/modals/connect-to-github-modal";
|
||||
import { ModalBackdrop } from "#/components/modals/modal-backdrop";
|
||||
import { retrieveAllGitHubUserRepositories } from "#/api/github";
|
||||
import store from "#/store";
|
||||
import { setInitialQuery } from "#/state/initial-query-slice";
|
||||
import { clientLoader as rootClientLoader } from "#/routes/_oh";
|
||||
import OpenHands from "#/api/open-hands";
|
||||
import { generateGitHubAuthUrl } from "#/utils/generate-github-auth-url";
|
||||
|
||||
interface GitHubAuthProps {
|
||||
onConnectToGitHub: () => void;
|
||||
repositories: GitHubRepository[];
|
||||
isLoggedIn: boolean;
|
||||
}
|
||||
|
||||
function GitHubAuth({
|
||||
onConnectToGitHub,
|
||||
repositories,
|
||||
isLoggedIn,
|
||||
}: GitHubAuthProps) {
|
||||
if (isLoggedIn) {
|
||||
return <GitHubRepositorySelector repositories={repositories} />;
|
||||
}
|
||||
|
||||
return (
|
||||
<ModalButton
|
||||
text="Connect to GitHub"
|
||||
icon={<GitHubLogo width={20} height={20} />}
|
||||
className="bg-[#791B80] w-full"
|
||||
onClick={onConnectToGitHub}
|
||||
/>
|
||||
);
|
||||
}
|
||||
import { GitHubRepositoriesSuggestionBox } from "#/components/github-repositories-suggestion-box";
|
||||
|
||||
export const clientLoader = async ({ request }: ClientLoaderFunctionArgs) => {
|
||||
let isSaas = false;
|
||||
@@ -67,12 +36,12 @@ export const clientLoader = async ({ request }: ClientLoaderFunctionArgs) => {
|
||||
const token = localStorage.getItem("token");
|
||||
if (token) return redirect("/app");
|
||||
|
||||
let repositories: GitHubRepository[] = [];
|
||||
let repositories: ReturnType<
|
||||
typeof retrieveAllGitHubUserRepositories
|
||||
> | null = null;
|
||||
if (ghToken) {
|
||||
const data = await retrieveAllGitHubUserRepositories(ghToken);
|
||||
if (!isGitHubErrorReponse(data)) {
|
||||
repositories = data;
|
||||
}
|
||||
const data = retrieveAllGitHubUserRepositories(ghToken);
|
||||
repositories = data;
|
||||
}
|
||||
|
||||
let githubAuthUrl: string | null = null;
|
||||
@@ -81,7 +50,7 @@ export const clientLoader = async ({ request }: ClientLoaderFunctionArgs) => {
|
||||
githubAuthUrl = generateGitHubAuthUrl(githubClientId, requestUrl);
|
||||
}
|
||||
|
||||
return json({ repositories, githubAuthUrl });
|
||||
return defer({ repositories, githubAuthUrl });
|
||||
};
|
||||
|
||||
export const clientAction = async ({ request }: ClientActionFunctionArgs) => {
|
||||
@@ -95,18 +64,8 @@ export const clientAction = async ({ request }: ClientActionFunctionArgs) => {
|
||||
function Home() {
|
||||
const rootData = useRouteLoaderData<typeof rootClientLoader>("routes/_oh");
|
||||
const { repositories, githubAuthUrl } = useLoaderData<typeof clientLoader>();
|
||||
const [connectToGitHubModalOpen, setConnectToGitHubModalOpen] =
|
||||
React.useState(false);
|
||||
const [importedFile, setImportedFile] = React.useState<File | null>(null);
|
||||
|
||||
const handleConnectToGitHub = () => {
|
||||
if (githubAuthUrl) {
|
||||
window.location.href = githubAuthUrl;
|
||||
} else {
|
||||
setConnectToGitHubModalOpen(true);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="bg-root-secondary h-full rounded-xl flex flex-col items-center justify-center relative overflow-y-auto">
|
||||
<HeroHeading />
|
||||
@@ -115,18 +74,24 @@ function Home() {
|
||||
<TaskForm importedProjectZip={importedFile} />
|
||||
</div>
|
||||
<div className="flex gap-4 w-full">
|
||||
<SuggestionBox
|
||||
title="Open a Repo"
|
||||
content={
|
||||
<GitHubAuth
|
||||
isLoggedIn={
|
||||
!!rootData?.user && !isGitHubErrorReponse(rootData.user)
|
||||
}
|
||||
repositories={repositories}
|
||||
onConnectToGitHub={handleConnectToGitHub}
|
||||
<Suspense
|
||||
fallback={
|
||||
<SuggestionBox
|
||||
title="Open a Repo"
|
||||
content="Loading repositories..."
|
||||
/>
|
||||
}
|
||||
/>
|
||||
>
|
||||
<Await resolve={repositories}>
|
||||
{(resolvedRepositories) => (
|
||||
<GitHubRepositoriesSuggestionBox
|
||||
repositories={resolvedRepositories}
|
||||
gitHubAuthUrl={githubAuthUrl}
|
||||
user={rootData?.user || null}
|
||||
/>
|
||||
)}
|
||||
</Await>
|
||||
</Suspense>
|
||||
<SuggestionBox
|
||||
title={importedFile ? "Project Loaded" : "+ Import Project"}
|
||||
content={
|
||||
@@ -159,13 +124,6 @@ function Home() {
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
{connectToGitHubModalOpen && (
|
||||
<ModalBackdrop onClose={() => setConnectToGitHubModalOpen(false)}>
|
||||
<ConnectToGitHubModal
|
||||
onClose={() => setConnectToGitHubModalOpen(false)}
|
||||
/>
|
||||
</ModalBackdrop>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,18 +1,21 @@
|
||||
import { Editor, Monaco } from "@monaco-editor/react";
|
||||
import { Editor, EditorProps } from "@monaco-editor/react";
|
||||
import React from "react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { VscCode } from "react-icons/vsc";
|
||||
import { type editor } from "monaco-editor";
|
||||
import toast from "react-hot-toast";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { useFiles } from "#/context/files";
|
||||
import OpenHands from "#/api/open-hands";
|
||||
|
||||
interface CodeEditorCompoonentProps {
|
||||
onMount: EditorProps["onMount"];
|
||||
isReadOnly: boolean;
|
||||
}
|
||||
|
||||
function CodeEditorCompoonent({ isReadOnly }: CodeEditorCompoonentProps) {
|
||||
function CodeEditorCompoonent({
|
||||
onMount,
|
||||
isReadOnly,
|
||||
}: CodeEditorCompoonentProps) {
|
||||
const { t } = useTranslation();
|
||||
const {
|
||||
files,
|
||||
@@ -22,22 +25,6 @@ function CodeEditorCompoonent({ isReadOnly }: CodeEditorCompoonentProps) {
|
||||
saveFileContent: saveNewFileContent,
|
||||
} = useFiles();
|
||||
|
||||
const handleEditorDidMount = React.useCallback(
|
||||
(editor: editor.IStandaloneCodeEditor, monaco: Monaco): void => {
|
||||
monaco.editor.defineTheme("my-theme", {
|
||||
base: "vs-dark",
|
||||
inherit: true,
|
||||
rules: [],
|
||||
colors: {
|
||||
"editor.background": "#171717",
|
||||
},
|
||||
});
|
||||
|
||||
monaco.editor.setTheme("my-theme");
|
||||
},
|
||||
[],
|
||||
);
|
||||
|
||||
const handleEditorChange = (value: string | undefined) => {
|
||||
if (selectedPath && value) modifyFileContent(selectedPath, value);
|
||||
};
|
||||
@@ -49,8 +36,7 @@ function CodeEditorCompoonent({ isReadOnly }: CodeEditorCompoonentProps) {
|
||||
|
||||
if (content) {
|
||||
try {
|
||||
const token = localStorage.getItem("token")?.toString();
|
||||
if (token) await OpenHands.saveFile(token, selectedPath, content);
|
||||
await OpenHands.saveFile(selectedPath, content);
|
||||
} catch (error) {
|
||||
toast.error("Failed to save file");
|
||||
}
|
||||
@@ -68,7 +54,7 @@ function CodeEditorCompoonent({ isReadOnly }: CodeEditorCompoonentProps) {
|
||||
return (
|
||||
<div
|
||||
data-testid="code-editor-empty-message"
|
||||
className="flex flex-col items-center text-neutral-400"
|
||||
className="flex flex-col h-full items-center justify-center text-neutral-400"
|
||||
>
|
||||
<VscCode size={100} />
|
||||
{t(I18nKey.CODE_EDITOR$EMPTY_MESSAGE)}
|
||||
@@ -79,7 +65,6 @@ function CodeEditorCompoonent({ isReadOnly }: CodeEditorCompoonentProps) {
|
||||
return (
|
||||
<Editor
|
||||
data-testid="code-editor"
|
||||
height="100%"
|
||||
path={selectedPath ?? undefined}
|
||||
defaultValue=""
|
||||
value={
|
||||
@@ -87,7 +72,7 @@ function CodeEditorCompoonent({ isReadOnly }: CodeEditorCompoonentProps) {
|
||||
? modifiedFiles[selectedPath] || files[selectedPath]
|
||||
: undefined
|
||||
}
|
||||
onMount={handleEditorDidMount}
|
||||
onMount={onMount}
|
||||
onChange={handleEditorChange}
|
||||
options={{ readOnly: isReadOnly }}
|
||||
/>
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import React from "react";
|
||||
import { useSelector } from "react-redux";
|
||||
import { json, useLoaderData, useRouteError } from "@remix-run/react";
|
||||
import { json, useRouteError } from "@remix-run/react";
|
||||
import toast from "react-hot-toast";
|
||||
import { editor } from "monaco-editor";
|
||||
import { EditorProps } from "@monaco-editor/react";
|
||||
import { RootState } from "#/store";
|
||||
import AgentState from "#/types/AgentState";
|
||||
import FileExplorer from "#/components/file-explorer/FileExplorer";
|
||||
import OpenHands from "#/api/open-hands";
|
||||
import { useSocket } from "#/context/socket";
|
||||
import CodeEditorCompoonent from "./code-editor-component";
|
||||
import { useFiles } from "#/context/files";
|
||||
import { EditorActions } from "#/components/editor-actions";
|
||||
@@ -28,8 +29,7 @@ export function ErrorBoundary() {
|
||||
}
|
||||
|
||||
function CodeEditor() {
|
||||
const { token } = useLoaderData<typeof clientLoader>();
|
||||
const { runtimeActive } = useSocket();
|
||||
const { curAgentState } = useSelector((state: RootState) => state.agent);
|
||||
const {
|
||||
setPaths,
|
||||
selectedPath,
|
||||
@@ -37,6 +37,27 @@ function CodeEditor() {
|
||||
saveFileContent: saveNewFileContent,
|
||||
discardChanges,
|
||||
} = useFiles();
|
||||
const [fileExplorerIsOpen, setFileExplorerIsOpen] = React.useState(true);
|
||||
const editorRef = React.useRef<editor.IStandaloneCodeEditor | null>(null);
|
||||
|
||||
const toggleFileExplorer = () => {
|
||||
setFileExplorerIsOpen((prev) => !prev);
|
||||
editorRef.current?.layout({ width: 0, height: 0 });
|
||||
};
|
||||
|
||||
const handleEditorDidMount: EditorProps["onMount"] = (e, monaco) => {
|
||||
editorRef.current = e;
|
||||
|
||||
monaco.editor.defineTheme("oh-dark", {
|
||||
base: "vs-dark",
|
||||
inherit: true,
|
||||
rules: [],
|
||||
colors: {
|
||||
"editor.background": "#171717",
|
||||
},
|
||||
});
|
||||
monaco.editor.setTheme("oh-dark");
|
||||
};
|
||||
|
||||
const [errors, setErrors] = React.useState<{ getFiles: string | null }>({
|
||||
getFiles: null,
|
||||
@@ -47,15 +68,14 @@ function CodeEditor() {
|
||||
);
|
||||
|
||||
React.useEffect(() => {
|
||||
// only retrieve files if connected to WS to prevent requesting before runtime is ready
|
||||
if (runtimeActive && token) {
|
||||
OpenHands.getFiles(token)
|
||||
if (curAgentState === AgentState.INIT) {
|
||||
OpenHands.getFiles()
|
||||
.then(setPaths)
|
||||
.catch(() => {
|
||||
setErrors({ getFiles: "Failed to retrieve files" });
|
||||
});
|
||||
}
|
||||
}, [runtimeActive, token]);
|
||||
}, [curAgentState]);
|
||||
|
||||
// Code editing is only allowed when the agent is paused, finished, or awaiting user input (server rules)
|
||||
const isEditingAllowed = React.useMemo(
|
||||
@@ -69,9 +89,9 @@ function CodeEditor() {
|
||||
const handleSave = async () => {
|
||||
if (selectedPath) {
|
||||
const content = modifiedFiles[selectedPath];
|
||||
if (content && token) {
|
||||
if (content) {
|
||||
try {
|
||||
await OpenHands.saveFile(token, selectedPath, content);
|
||||
await OpenHands.saveFile(selectedPath, content);
|
||||
saveNewFileContent(selectedPath);
|
||||
} catch (error) {
|
||||
toast.error("Failed to save file");
|
||||
@@ -85,9 +105,13 @@ function CodeEditor() {
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex h-full w-full bg-neutral-900 relative">
|
||||
<FileExplorer error={errors.getFiles} />
|
||||
<div className="flex flex-col min-h-0 w-full">
|
||||
<div className="flex h-full bg-neutral-900 relative">
|
||||
<FileExplorer
|
||||
isOpen={fileExplorerIsOpen}
|
||||
onToggle={toggleFileExplorer}
|
||||
error={errors.getFiles}
|
||||
/>
|
||||
<div className="w-full">
|
||||
{selectedPath && (
|
||||
<div className="flex w-full items-center justify-between self-end p-2">
|
||||
<span className="text-sm text-neutral-500">{selectedPath}</span>
|
||||
@@ -98,9 +122,10 @@ function CodeEditor() {
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
<div className="flex grow items-center justify-center">
|
||||
<CodeEditorCompoonent isReadOnly={!isEditingAllowed} />
|
||||
</div>
|
||||
<CodeEditorCompoonent
|
||||
onMount={handleEditorDidMount}
|
||||
isReadOnly={!isEditingAllowed}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -72,9 +72,8 @@ const isAgentStateChange = (
|
||||
|
||||
export const clientLoader = async () => {
|
||||
const ghToken = localStorage.getItem("ghToken");
|
||||
|
||||
try {
|
||||
const isAuthed = await userIsAuthenticated(ghToken);
|
||||
const isAuthed = await userIsAuthenticated();
|
||||
if (!isAuthed) {
|
||||
clearSession();
|
||||
return redirect("/");
|
||||
@@ -290,21 +289,21 @@ function App() {
|
||||
|
||||
React.useEffect(() => {
|
||||
(async () => {
|
||||
if (runtimeActive && token && importedProjectZip) {
|
||||
if (runtimeActive && importedProjectZip) {
|
||||
// upload files action
|
||||
try {
|
||||
const blob = base64ToBlob(importedProjectZip);
|
||||
const file = new File([blob], "imported-project.zip", {
|
||||
type: blob.type,
|
||||
});
|
||||
await OpenHands.uploadFiles(token, [file]);
|
||||
await OpenHands.uploadFiles([file]);
|
||||
dispatch(setImportedProjectZip(null));
|
||||
} catch (error) {
|
||||
toast.error("Failed to upload project files.");
|
||||
}
|
||||
}
|
||||
})();
|
||||
}, [runtimeActive, token, importedProjectZip]);
|
||||
}, [runtimeActive, importedProjectZip]);
|
||||
|
||||
const {
|
||||
isOpen: securityModalIsOpen,
|
||||
@@ -315,7 +314,7 @@ function App() {
|
||||
return (
|
||||
<div className="flex flex-col h-full gap-3">
|
||||
<div className="flex h-full overflow-auto gap-3">
|
||||
<Container className="w-[375px] max-h-full">
|
||||
<Container className="w-[390px] max-h-full">
|
||||
<ChatInterface />
|
||||
</Container>
|
||||
|
||||
|
||||
@@ -10,6 +10,8 @@ import {
|
||||
Outlet,
|
||||
ClientLoaderFunctionArgs,
|
||||
} from "@remix-run/react";
|
||||
import posthog from "posthog-js";
|
||||
import { useDispatch } from "react-redux";
|
||||
import { retrieveGitHubUser, isGitHubErrorReponse } from "#/api/github";
|
||||
import OpenHands from "#/api/open-hands";
|
||||
import CogTooth from "#/assets/cog-tooth";
|
||||
@@ -28,6 +30,9 @@ import DocsIcon from "#/assets/docs.svg?react";
|
||||
import { userIsAuthenticated } from "#/utils/user-is-authenticated";
|
||||
import { generateGitHubAuthUrl } from "#/utils/generate-github-auth-url";
|
||||
import { WaitlistModal } from "#/components/waitlist-modal";
|
||||
import { AnalyticsConsentFormModal } from "#/components/analytics-consent-form-modal";
|
||||
import { setCurrentAgentState } from "#/state/agentSlice";
|
||||
import AgentState from "#/types/AgentState";
|
||||
|
||||
export const clientLoader = async ({ request }: ClientLoaderFunctionArgs) => {
|
||||
try {
|
||||
@@ -41,12 +46,20 @@ export const clientLoader = async ({ request }: ClientLoaderFunctionArgs) => {
|
||||
|
||||
let token = localStorage.getItem("token");
|
||||
const ghToken = localStorage.getItem("ghToken");
|
||||
const analyticsConsent = localStorage.getItem("analytics-consent");
|
||||
const userConsents = analyticsConsent === "true";
|
||||
|
||||
let isAuthed: boolean = false;
|
||||
if (!userConsents) {
|
||||
posthog.opt_out_capturing();
|
||||
} else {
|
||||
posthog.opt_in_capturing();
|
||||
}
|
||||
|
||||
let isAuthed = false;
|
||||
let githubAuthUrl: string | null = null;
|
||||
|
||||
try {
|
||||
isAuthed = await userIsAuthenticated(ghToken);
|
||||
isAuthed = await userIsAuthenticated();
|
||||
if (!isAuthed && window.__GITHUB_CLIENT_ID__) {
|
||||
const requestUrl = new URL(request.url);
|
||||
githubAuthUrl = generateGitHubAuthUrl(
|
||||
@@ -79,6 +92,7 @@ export const clientLoader = async ({ request }: ClientLoaderFunctionArgs) => {
|
||||
user,
|
||||
settingsIsUpdated,
|
||||
settings,
|
||||
analyticsConsent,
|
||||
});
|
||||
};
|
||||
|
||||
@@ -132,9 +146,11 @@ export default function MainApp() {
|
||||
githubAuthUrl,
|
||||
settingsIsUpdated,
|
||||
settings,
|
||||
analyticsConsent,
|
||||
} = useLoaderData<typeof clientLoader>();
|
||||
const logoutFetcher = useFetcher({ key: "logout" });
|
||||
const endSessionFetcher = useFetcher({ key: "end-session" });
|
||||
const dispatch = useDispatch();
|
||||
|
||||
const [accountSettingsModalOpen, setAccountSettingsModalOpen] =
|
||||
React.useState(false);
|
||||
@@ -204,6 +220,7 @@ export default function MainApp() {
|
||||
|
||||
const handleEndSession = () => {
|
||||
setStartNewProjectModalIsOpen(false);
|
||||
dispatch(setCurrentAgentState(AgentState.LOADING));
|
||||
// call new session action and redirect to '/'
|
||||
endSessionFetcher.submit(new FormData(), {
|
||||
method: "POST",
|
||||
@@ -304,6 +321,7 @@ export default function MainApp() {
|
||||
onClose={handleAccountSettingsModalClose}
|
||||
selectedLanguage={settings.LANGUAGE}
|
||||
gitHubError={isGitHubErrorReponse(user)}
|
||||
analyticsConsent={analyticsConsent}
|
||||
/>
|
||||
</ModalBackdrop>
|
||||
)}
|
||||
@@ -328,6 +346,7 @@ export default function MainApp() {
|
||||
{!isAuthed && (
|
||||
<WaitlistModal ghToken={ghToken} githubAuthUrl={githubAuthUrl} />
|
||||
)}
|
||||
{!analyticsConsent && <AnalyticsConsentFormModal />}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -11,11 +11,11 @@ export const clientLoader = async ({ request }: ClientLoaderFunctionArgs) => {
|
||||
const code = url.searchParams.get("code");
|
||||
|
||||
if (code) {
|
||||
// request to the server to exchange the code for a token
|
||||
const { access_token: accessToken } =
|
||||
await OpenHands.getGitHubAccessToken(code);
|
||||
// set the token in local storage
|
||||
|
||||
localStorage.setItem("ghToken", accessToken);
|
||||
|
||||
return redirect("/");
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
import { ClientActionFunctionArgs, json } from "@remix-run/react";
|
||||
|
||||
export const clientAction = async ({ request }: ClientActionFunctionArgs) => {
|
||||
const formData = await request.formData();
|
||||
const userConsents = formData.get("analytics") === "on";
|
||||
localStorage.setItem("analytics-consent", userConsents.toString());
|
||||
|
||||
return json(null);
|
||||
};
|
||||
@@ -28,6 +28,9 @@ export const clientAction = async ({ request }: ClientActionFunctionArgs) => {
|
||||
const LANGUAGE = formData.get("language")?.toString();
|
||||
if (LANGUAGE) saveSettings({ LANGUAGE });
|
||||
|
||||
const ANALYTICS = formData.get("analytics")?.toString() ?? "false";
|
||||
localStorage.setItem("analytics-consent", ANALYTICS);
|
||||
|
||||
return json({ success: true });
|
||||
}
|
||||
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
import { ClientActionFunctionArgs, json } from "@remix-run/react";
|
||||
import { Feedback } from "#/api/open-hands.types";
|
||||
import OpenHands from "#/api/open-hands";
|
||||
|
||||
const VIEWER_PAGE = "https://www.all-hands.dev/share";
|
||||
|
||||
const isFeedback = (feedback: unknown): feedback is Feedback => {
|
||||
if (typeof feedback !== "object" || feedback === null) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return (
|
||||
"version" in feedback &&
|
||||
"email" in feedback &&
|
||||
"token" in feedback &&
|
||||
"feedback" in feedback &&
|
||||
"permissions" in feedback &&
|
||||
"trajectory" in feedback
|
||||
);
|
||||
};
|
||||
|
||||
export const clientAction = async ({ request }: ClientActionFunctionArgs) => {
|
||||
const formData = await request.formData();
|
||||
const feedback = formData.get("feedback")?.toString();
|
||||
const token = localStorage.getItem("token");
|
||||
|
||||
if (token && feedback) {
|
||||
const parsed = JSON.parse(feedback);
|
||||
if (isFeedback(parsed)) {
|
||||
try {
|
||||
const response = await OpenHands.sendFeedback(token, parsed);
|
||||
if (response.statusCode === 200) {
|
||||
const { message, feedback_id: feedbackId, password } = response.body;
|
||||
const link = `${VIEWER_PAGE}?share_id=${feedbackId}`;
|
||||
return json({
|
||||
success: true,
|
||||
data: { message, link, password },
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
return json({ success: false, data: null });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return json({ success: false, data: null });
|
||||
};
|
||||
@@ -1,14 +1,26 @@
|
||||
import { getToken } from "./auth";
|
||||
import { getToken, getGitHubToken } from "./auth";
|
||||
import toast from "#/utils/toast";
|
||||
|
||||
const WAIT_FOR_AUTH_DELAY_MS = 500;
|
||||
|
||||
const UNAUTHED_ROUTE_PREFIXES = [
|
||||
"/api/authenticate",
|
||||
"/api/options/",
|
||||
"/config.json",
|
||||
"/api/github/callback",
|
||||
];
|
||||
|
||||
export async function request(
|
||||
url: string,
|
||||
options: RequestInit = {},
|
||||
disableToast: boolean = false,
|
||||
returnResponse: boolean = false,
|
||||
maxRetries: number = 3,
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
): Promise<any> {
|
||||
if (maxRetries < 0) {
|
||||
throw new Error("Max retries exceeded");
|
||||
}
|
||||
const onFail = (msg: string) => {
|
||||
if (!disableToast) {
|
||||
toast.error("api", msg);
|
||||
@@ -16,12 +28,17 @@ export async function request(
|
||||
throw new Error(msg);
|
||||
};
|
||||
|
||||
const needsAuth = !url.startsWith("/api/options/");
|
||||
const needsAuth = !UNAUTHED_ROUTE_PREFIXES.some((prefix) =>
|
||||
url.startsWith(prefix),
|
||||
);
|
||||
const token = getToken();
|
||||
const githubToken = getGitHubToken();
|
||||
if (!token && needsAuth) {
|
||||
return new Promise((resolve) => {
|
||||
setTimeout(() => {
|
||||
resolve(request(url, options, disableToast));
|
||||
resolve(
|
||||
request(url, options, disableToast, returnResponse, maxRetries - 1),
|
||||
);
|
||||
}, WAIT_FOR_AUTH_DELAY_MS);
|
||||
});
|
||||
}
|
||||
@@ -32,6 +49,13 @@ export async function request(
|
||||
Authorization: `Bearer ${token}`,
|
||||
};
|
||||
}
|
||||
if (githubToken) {
|
||||
// eslint-disable-next-line no-param-reassign
|
||||
options.headers = {
|
||||
...(options.headers || {}),
|
||||
"X-GitHub-Token": githubToken,
|
||||
};
|
||||
}
|
||||
|
||||
let response = null;
|
||||
try {
|
||||
@@ -48,6 +72,10 @@ export async function request(
|
||||
onFail(`Error fetching ${url}: ${response?.statusText}`);
|
||||
}
|
||||
|
||||
if (returnResponse) {
|
||||
return response;
|
||||
}
|
||||
|
||||
try {
|
||||
return await (response && response.json());
|
||||
} catch (e) {
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
const TOKEN_KEY = "token";
|
||||
const GITHUB_TOKEN_KEY = "ghToken";
|
||||
|
||||
const getToken = (): string => localStorage.getItem(TOKEN_KEY) ?? "";
|
||||
|
||||
@@ -10,4 +11,22 @@ const setToken = (token: string): void => {
|
||||
localStorage.setItem(TOKEN_KEY, token);
|
||||
};
|
||||
|
||||
export { getToken, setToken, clearToken };
|
||||
const getGitHubToken = (): string =>
|
||||
localStorage.getItem(GITHUB_TOKEN_KEY) ?? "";
|
||||
|
||||
const setGitHubToken = (token: string): void => {
|
||||
localStorage.setItem(GITHUB_TOKEN_KEY, token);
|
||||
};
|
||||
|
||||
const clearGitHubToken = (): void => {
|
||||
localStorage.removeItem(GITHUB_TOKEN_KEY);
|
||||
};
|
||||
|
||||
export {
|
||||
getToken,
|
||||
setToken,
|
||||
clearToken,
|
||||
getGitHubToken,
|
||||
setGitHubToken,
|
||||
clearGitHubToken,
|
||||
};
|
||||
|
||||
@@ -4,12 +4,7 @@ import OpenHands from "#/api/open-hands";
|
||||
* Downloads the current workspace as a .zip file.
|
||||
*/
|
||||
export const downloadWorkspace = async () => {
|
||||
const token = localStorage.getItem("token");
|
||||
if (!token) {
|
||||
throw new Error("No token found");
|
||||
}
|
||||
|
||||
const blob = await OpenHands.getWorkspaceZip(token);
|
||||
const blob = await OpenHands.getWorkspaceZip();
|
||||
|
||||
const url = URL.createObjectURL(blob);
|
||||
const link = document.createElement("a");
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
/**
|
||||
* Get the valid fallback host. Returns the host unless it is localhost, in which case it returns localhost:3000
|
||||
* @returns Valid fallback host
|
||||
*
|
||||
* @example
|
||||
* // If the host is localhost (e.g., localhost:5173), it returns localhost:3000
|
||||
* const host = getValidFallbackHost(); // localhost:3000
|
||||
*
|
||||
* // If the host is not localhost, it returns the host
|
||||
* const host = getValidFallbackHost(); // sub.example.com
|
||||
*/
|
||||
export const getValidFallbackHost = () => {
|
||||
if (typeof window !== "undefined") {
|
||||
return window.location.host;
|
||||
}
|
||||
|
||||
// Fallback is localhost:3000 because that is the default port for the server
|
||||
return "localhost:3000";
|
||||
};
|
||||
@@ -1,16 +1,12 @@
|
||||
import { retrieveGitHubUser, isGitHubErrorReponse } from "#/api/github";
|
||||
import OpenHands from "#/api/open-hands";
|
||||
|
||||
export const userIsAuthenticated = async (ghToken: string | null) => {
|
||||
if (window.__APP_MODE__ !== "saas") return true;
|
||||
export const userIsAuthenticated = async () => {
|
||||
if (window.__APP_MODE__ === "oss") return true;
|
||||
|
||||
let user: GitHubUser | GitHubErrorReponse | null = null;
|
||||
if (ghToken) user = await retrieveGitHubUser(ghToken);
|
||||
|
||||
if (user && !isGitHubErrorReponse(user)) {
|
||||
const isAuthed = await OpenHands.isAuthenticated(user.login);
|
||||
return isAuthed;
|
||||
try {
|
||||
await OpenHands.authenticate();
|
||||
return true;
|
||||
} catch (error) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
+21
-13
@@ -37,21 +37,29 @@ export const removeUnwantedKeys = (
|
||||
"focused_element_bid",
|
||||
];
|
||||
|
||||
return data.map((item) => {
|
||||
// Create a shallow copy of item
|
||||
const newItem = { ...item };
|
||||
return data
|
||||
.filter((item) => {
|
||||
// Skip items that have a status key
|
||||
if ("status" in item) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
})
|
||||
.map((item) => {
|
||||
// Create a shallow copy of item
|
||||
const newItem = { ...item };
|
||||
|
||||
// Check if extras exists and delete it from a new extras object
|
||||
if (newItem.extras) {
|
||||
const newExtras = { ...newItem.extras };
|
||||
UNDESIRED_KEYS.forEach((key) => {
|
||||
delete newExtras[key as keyof typeof newExtras];
|
||||
});
|
||||
newItem.extras = newExtras;
|
||||
}
|
||||
// Check if extras exists and delete it from a new extras object
|
||||
if (newItem.extras) {
|
||||
const newExtras = { ...newItem.extras };
|
||||
UNDESIRED_KEYS.forEach((key) => {
|
||||
delete newExtras[key as keyof typeof newExtras];
|
||||
});
|
||||
newItem.extras = newExtras;
|
||||
}
|
||||
|
||||
return newItem;
|
||||
});
|
||||
return newItem;
|
||||
});
|
||||
};
|
||||
|
||||
export const removeApiKey = (
|
||||
|
||||
@@ -16,6 +16,7 @@ from openhands.events.action import (
|
||||
Action,
|
||||
AgentDelegateAction,
|
||||
AgentFinishAction,
|
||||
BrowseInteractiveAction,
|
||||
CmdRunAction,
|
||||
FileEditAction,
|
||||
IPythonRunCellAction,
|
||||
@@ -23,6 +24,7 @@ from openhands.events.action import (
|
||||
)
|
||||
from openhands.events.observation import (
|
||||
AgentDelegateObservation,
|
||||
BrowserOutputObservation,
|
||||
CmdOutputObservation,
|
||||
FileEditObservation,
|
||||
IPythonRunCellObservation,
|
||||
@@ -42,7 +44,7 @@ from openhands.utils.prompt import PromptManager
|
||||
|
||||
|
||||
class CodeActAgent(Agent):
|
||||
VERSION = '2.1'
|
||||
VERSION = '2.2'
|
||||
"""
|
||||
The Code Act Agent is a minimalist agent.
|
||||
The agent works by passing the model a list of action-observation pairs and prompting the model to take the next step.
|
||||
@@ -105,11 +107,11 @@ class CodeActAgent(Agent):
|
||||
if self.function_calling_active:
|
||||
# Function calling mode
|
||||
self.tools = codeact_function_calling.get_tools(
|
||||
codeact_enable_browsing_delegate=self.config.codeact_enable_browsing_delegate,
|
||||
codeact_enable_browsing=self.config.codeact_enable_browsing,
|
||||
codeact_enable_jupyter=self.config.codeact_enable_jupyter,
|
||||
codeact_enable_llm_editor=self.config.codeact_enable_llm_editor,
|
||||
)
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f'TOOLS loaded for CodeActAgent: {json.dumps(self.tools, indent=2)}'
|
||||
)
|
||||
self.system_prompt = codeact_function_calling.SYSTEM_PROMPT
|
||||
@@ -142,10 +144,10 @@ class CodeActAgent(Agent):
|
||||
|
||||
Args:
|
||||
action (Action): The action to convert. Can be one of:
|
||||
- AgentDelegateAction: For delegating tasks to other agents
|
||||
- CmdRunAction: For executing bash commands
|
||||
- IPythonRunCellAction: For running IPython code
|
||||
- FileEditAction: For editing files
|
||||
- BrowseInteractiveAction: For browsing the web
|
||||
- AgentFinishAction: For ending the interaction
|
||||
- MessageAction: For sending messages
|
||||
pending_tool_call_action_messages (dict[str, Message]): Dictionary mapping response IDs
|
||||
@@ -169,6 +171,7 @@ class CodeActAgent(Agent):
|
||||
CmdRunAction,
|
||||
IPythonRunCellAction,
|
||||
FileEditAction,
|
||||
BrowseInteractiveAction,
|
||||
),
|
||||
) or (isinstance(action, AgentFinishAction) and action.source == 'agent'):
|
||||
if self.function_calling_active:
|
||||
@@ -185,13 +188,17 @@ class CodeActAgent(Agent):
|
||||
pending_tool_call_action_messages[llm_response.id] = Message(
|
||||
role=assistant_msg.role,
|
||||
# tool call content SHOULD BE a string
|
||||
content=[TextContent(text=assistant_msg.content)]
|
||||
content=[TextContent(text=assistant_msg.content or '')]
|
||||
if assistant_msg.content is not None
|
||||
else [],
|
||||
tool_calls=assistant_msg.tool_calls,
|
||||
)
|
||||
return []
|
||||
else:
|
||||
assert not isinstance(action, BrowseInteractiveAction), (
|
||||
'BrowseInteractiveAction is not supported in non-function calling mode. Action: '
|
||||
+ str(action)
|
||||
)
|
||||
content = [TextContent(text=self.action_parser.action_to_str(action))]
|
||||
return [
|
||||
Message(
|
||||
@@ -201,7 +208,7 @@ class CodeActAgent(Agent):
|
||||
]
|
||||
elif isinstance(action, MessageAction):
|
||||
role = 'user' if action.source == 'user' else 'assistant'
|
||||
content = [TextContent(text=action.content)]
|
||||
content = [TextContent(text=action.content or '')]
|
||||
if self.llm.vision_is_active() and action.images_urls:
|
||||
content.append(ImageContent(image_urls=action.images_urls))
|
||||
return [
|
||||
@@ -266,6 +273,12 @@ class CodeActAgent(Agent):
|
||||
elif isinstance(obs, FileEditObservation):
|
||||
text = obs_prefix + truncate_content(str(obs), max_message_chars)
|
||||
message = Message(role='user', content=[TextContent(text=text)])
|
||||
elif isinstance(obs, BrowserOutputObservation):
|
||||
text = obs.get_agent_obs_text()
|
||||
message = Message(
|
||||
role='user',
|
||||
content=[TextContent(text=obs_prefix + text)],
|
||||
)
|
||||
elif isinstance(obs, AgentDelegateObservation):
|
||||
text = obs_prefix + truncate_content(
|
||||
obs.outputs['content'] if 'content' in obs.outputs else '',
|
||||
@@ -335,6 +348,7 @@ class CodeActAgent(Agent):
|
||||
}
|
||||
if self.function_calling_active:
|
||||
params['tools'] = self.tools
|
||||
params['parallel_tool_calls'] = False
|
||||
else:
|
||||
params['stop'] = [
|
||||
'</execute_ipython>',
|
||||
|
||||
@@ -5,6 +5,7 @@ This is similar to the functionality of `CodeActResponseParser`.
|
||||
|
||||
import json
|
||||
|
||||
from browsergym.core.action.highlevel import HighLevelActionSet
|
||||
from litellm import (
|
||||
ChatCompletionToolParam,
|
||||
ChatCompletionToolParamFunctionChunk,
|
||||
@@ -16,6 +17,7 @@ from openhands.events.action import (
|
||||
Action,
|
||||
AgentDelegateAction,
|
||||
AgentFinishAction,
|
||||
BrowseInteractiveAction,
|
||||
CmdRunAction,
|
||||
FileEditAction,
|
||||
IPythonRunCellAction,
|
||||
@@ -272,24 +274,146 @@ StrReplaceEditorTool = ChatCompletionToolParam(
|
||||
),
|
||||
)
|
||||
|
||||
_BROWSER_DELEGATION = """Delegate the task to another browsing agent.
|
||||
The assistant should delegate the task if it needs to browse the Internet.
|
||||
# from browsergym/core/action/highlevel.py
|
||||
_browser_action_space = HighLevelActionSet(
|
||||
subsets=['bid', 'nav'],
|
||||
strict=False, # less strict on the parsing of the actions
|
||||
multiaction=True, # enable to agent to take multiple actions at once
|
||||
)
|
||||
|
||||
|
||||
_BROWSER_DESCRIPTION = """Interact with the browser using Python code.
|
||||
The following 15 functions are available. Nothing else is supported.
|
||||
|
||||
goto(url: str)
|
||||
Description: Navigate to a url.
|
||||
Examples:
|
||||
goto('http://www.example.com')
|
||||
|
||||
go_back()
|
||||
Description: Navigate to the previous page in history.
|
||||
Examples:
|
||||
go_back()
|
||||
|
||||
go_forward()
|
||||
Description: Navigate to the next page in history.
|
||||
Examples:
|
||||
go_forward()
|
||||
|
||||
noop(wait_ms: float = 1000)
|
||||
Description: Do nothing, and optionally wait for the given time (in milliseconds).
|
||||
You can use this to get the current page content and/or wait for the page to load.
|
||||
Examples:
|
||||
noop()
|
||||
|
||||
noop(500)
|
||||
|
||||
scroll(delta_x: float, delta_y: float)
|
||||
Description: Scroll horizontally and vertically. Amounts in pixels, positive for right or down scrolling, negative for left or up scrolling. Dispatches a wheel event.
|
||||
Examples:
|
||||
scroll(0, 200)
|
||||
|
||||
scroll(-50.2, -100.5)
|
||||
|
||||
fill(bid: str, value: str)
|
||||
Description: Fill out a form field. It focuses the element and triggers an input event with the entered text. It works for <input>, <textarea> and [contenteditable] elements.
|
||||
Examples:
|
||||
fill('237', 'example value')
|
||||
|
||||
fill('45', 'multi-line\nexample')
|
||||
|
||||
fill('a12', 'example with "quotes"')
|
||||
|
||||
select_option(bid: str, options: str | list[str])
|
||||
Description: Select one or multiple options in a <select> element. You can specify option value or label to select. Multiple options can be selected.
|
||||
Examples:
|
||||
select_option('a48', 'blue')
|
||||
|
||||
select_option('c48', ['red', 'green', 'blue'])
|
||||
|
||||
click(bid: str, button: Literal['left', 'middle', 'right'] = 'left', modifiers: list[typing.Literal['Alt', 'Control', 'ControlOrMeta', 'Meta', 'Shift']] = [])
|
||||
Description: Click an element.
|
||||
Examples:
|
||||
click('a51')
|
||||
|
||||
click('b22', button='right')
|
||||
|
||||
click('48', button='middle', modifiers=['Shift'])
|
||||
|
||||
dblclick(bid: str, button: Literal['left', 'middle', 'right'] = 'left', modifiers: list[typing.Literal['Alt', 'Control', 'ControlOrMeta', 'Meta', 'Shift']] = [])
|
||||
Description: Double click an element.
|
||||
Examples:
|
||||
dblclick('12')
|
||||
|
||||
dblclick('ca42', button='right')
|
||||
|
||||
dblclick('178', button='middle', modifiers=['Shift'])
|
||||
|
||||
hover(bid: str)
|
||||
Description: Hover over an element.
|
||||
Examples:
|
||||
hover('b8')
|
||||
|
||||
press(bid: str, key_comb: str)
|
||||
Description: Focus the matching element and press a combination of keys. It accepts the logical key names that are emitted in the keyboardEvent.key property of the keyboard events: Backquote, Minus, Equal, Backslash, Backspace, Tab, Delete, Escape, ArrowDown, End, Enter, Home, Insert, PageDown, PageUp, ArrowRight, ArrowUp, F1 - F12, Digit0 - Digit9, KeyA - KeyZ, etc. You can alternatively specify a single character you'd like to produce such as "a" or "#". Following modification shortcuts are also supported: Shift, Control, Alt, Meta, ShiftLeft, ControlOrMeta. ControlOrMeta resolves to Control on Windows and Linux and to Meta on macOS.
|
||||
Examples:
|
||||
press('88', 'Backspace')
|
||||
|
||||
press('a26', 'ControlOrMeta+a')
|
||||
|
||||
press('a61', 'Meta+Shift+t')
|
||||
|
||||
focus(bid: str)
|
||||
Description: Focus the matching element.
|
||||
Examples:
|
||||
focus('b455')
|
||||
|
||||
clear(bid: str)
|
||||
Description: Clear the input field.
|
||||
Examples:
|
||||
clear('996')
|
||||
|
||||
drag_and_drop(from_bid: str, to_bid: str)
|
||||
Description: Perform a drag & drop. Hover the element that will be dragged. Press left mouse button. Move mouse to the element that will receive the drop. Release left mouse button.
|
||||
Examples:
|
||||
drag_and_drop('56', '498')
|
||||
|
||||
upload_file(bid: str, file: str | list[str])
|
||||
Description: Click an element and wait for a "filechooser" event, then select one or multiple input files for upload. Relative file paths are resolved relative to the current working directory. An empty list clears the selected files.
|
||||
Examples:
|
||||
upload_file('572', '/home/user/my_receipt.pdf')
|
||||
|
||||
upload_file('63', ['/home/bob/Documents/image.jpg', '/home/bob/Documents/file.zip'])
|
||||
|
||||
Multiple actions can be provided at once, but will be executed sequentially without any feedback from the page.
|
||||
More than 2-3 actions usually leads to failure or unexpected behavior. Example:
|
||||
fill('a12', 'example with "quotes"')
|
||||
click('a51')
|
||||
click('48', button='middle', modifiers=['Shift'])
|
||||
"""
|
||||
|
||||
BrowserDelegationTool = ChatCompletionToolParam(
|
||||
for _, action in _browser_action_space.action_set.items():
|
||||
assert (
|
||||
action.signature in _BROWSER_DESCRIPTION
|
||||
), f'Browser description mismatch. Please double check if the BrowserGym updated their action space.\n\nAction: {action.signature}'
|
||||
assert (
|
||||
action.description in _BROWSER_DESCRIPTION
|
||||
), f'Browser description mismatch. Please double check if the BrowserGym updated their action space.\n\nAction: {action.description}'
|
||||
|
||||
BrowserTool = ChatCompletionToolParam(
|
||||
type='function',
|
||||
function=ChatCompletionToolParamFunctionChunk(
|
||||
name='delegate_to_browsing_agent',
|
||||
description=_BROWSER_DELEGATION,
|
||||
name='browser',
|
||||
description=_BROWSER_DESCRIPTION,
|
||||
parameters={
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'task': {
|
||||
'code': {
|
||||
'type': 'string',
|
||||
'description': 'The task for the browsing agent to execute. It should include all the necessary context and specify what information the browsing agent should return.',
|
||||
},
|
||||
'description': 'The Python code that interacts with the browser.',
|
||||
}
|
||||
},
|
||||
'required': ['task'],
|
||||
'required': ['code'],
|
||||
},
|
||||
),
|
||||
)
|
||||
@@ -357,6 +481,8 @@ def response_to_actions(response: ModelResponse) -> list[Action]:
|
||||
f'TOOL CALL: str_replace_editor -> file_editor with code: {code}'
|
||||
)
|
||||
action = IPythonRunCellAction(code=code, include_extra=False)
|
||||
elif tool_call.function.name == 'browser':
|
||||
action = BrowseInteractiveAction(browser_actions=arguments['code'])
|
||||
else:
|
||||
raise RuntimeError(f'Unknown tool call: {tool_call.function.name}')
|
||||
|
||||
@@ -381,13 +507,13 @@ def response_to_actions(response: ModelResponse) -> list[Action]:
|
||||
|
||||
|
||||
def get_tools(
|
||||
codeact_enable_browsing_delegate: bool = False,
|
||||
codeact_enable_browsing: bool = False,
|
||||
codeact_enable_llm_editor: bool = False,
|
||||
codeact_enable_jupyter: bool = False,
|
||||
) -> list[ChatCompletionToolParam]:
|
||||
tools = [CmdRunTool, FinishTool]
|
||||
if codeact_enable_browsing_delegate:
|
||||
tools.append(BrowserDelegationTool)
|
||||
if codeact_enable_browsing:
|
||||
tools.append(BrowserTool)
|
||||
if codeact_enable_jupyter:
|
||||
tools.append(IPythonTool)
|
||||
if codeact_enable_llm_editor:
|
||||
|
||||
@@ -156,7 +156,7 @@ class AgentController:
|
||||
if exception is not None and isinstance(exception, litellm.AuthenticationError):
|
||||
detail = 'Please check your credentials. Is your API key correct?'
|
||||
self.event_stream.add_event(
|
||||
ErrorObservation(f'{message}:{detail}'), EventSource.USER
|
||||
ErrorObservation(f'{message}:{detail}'), EventSource.ENVIRONMENT
|
||||
)
|
||||
|
||||
async def start_step_loop(self):
|
||||
@@ -346,7 +346,8 @@ class AgentController:
|
||||
|
||||
self.state.agent_state = new_state
|
||||
self.event_stream.add_event(
|
||||
AgentStateChangedObservation('', self.state.agent_state), EventSource.AGENT
|
||||
AgentStateChangedObservation('', self.state.agent_state),
|
||||
EventSource.ENVIRONMENT,
|
||||
)
|
||||
|
||||
if new_state == AgentState.INIT and self.state.resume_state:
|
||||
@@ -423,7 +424,8 @@ class AgentController:
|
||||
if self._is_stuck():
|
||||
# This need to go BEFORE report_error to sync metrics
|
||||
self.event_stream.add_event(
|
||||
FatalErrorObservation('Agent got stuck in a loop'), EventSource.USER
|
||||
FatalErrorObservation('Agent got stuck in a loop'),
|
||||
EventSource.ENVIRONMENT,
|
||||
)
|
||||
return
|
||||
|
||||
@@ -507,6 +509,16 @@ class AgentController:
|
||||
# update iteration that shall be shared across agents
|
||||
self.state.iteration = self.delegate.state.iteration
|
||||
|
||||
# emit AgentDelegateObservation when the delegate terminates due to error
|
||||
delegate_outputs = (
|
||||
self.delegate.state.outputs if self.delegate.state else {}
|
||||
)
|
||||
content = (
|
||||
f'{self.delegate.agent.name} encountered an error during execution.'
|
||||
)
|
||||
obs = AgentDelegateObservation(outputs=delegate_outputs, content=content)
|
||||
self.event_stream.add_event(obs, EventSource.AGENT)
|
||||
|
||||
# close the delegate upon error
|
||||
await self.delegate.close()
|
||||
self.delegate = None
|
||||
@@ -532,9 +544,7 @@ class AgentController:
|
||||
content = (
|
||||
f'{self.delegate.agent.name} finishes task with {formatted_output}'
|
||||
)
|
||||
obs: Observation = AgentDelegateObservation(
|
||||
outputs=outputs, content=content
|
||||
)
|
||||
obs = AgentDelegateObservation(outputs=outputs, content=content)
|
||||
|
||||
# clean up delegate status
|
||||
self.delegate = None
|
||||
|
||||
@@ -61,7 +61,7 @@ def display_event(event: Event):
|
||||
if hasattr(event, 'thought'):
|
||||
display_message(event.thought)
|
||||
if isinstance(event, MessageAction):
|
||||
if event.source != EventSource.USER:
|
||||
if event.source == EventSource.AGENT:
|
||||
display_message(event.content)
|
||||
if isinstance(event, CmdRunAction):
|
||||
display_command(event.command)
|
||||
@@ -131,7 +131,7 @@ async def main():
|
||||
next_message = input('How can I help? >> ')
|
||||
if next_message == 'exit':
|
||||
event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.STOPPED), EventSource.USER
|
||||
ChangeAgentStateAction(AgentState.STOPPED), EventSource.ENVIRONMENT
|
||||
)
|
||||
return
|
||||
action = MessageAction(content=next_message)
|
||||
|
||||
@@ -9,7 +9,7 @@ class AgentConfig:
|
||||
|
||||
Attributes:
|
||||
function_calling: Whether function calling is enabled. Default is True.
|
||||
codeact_enable_browsing_delegate: Whether browsing delegate is enabled in the action space. Default is False. Only works with function calling.
|
||||
codeact_enable_browsing: Whether browsing delegate is enabled in the action space. Default is False. Only works with function calling.
|
||||
codeact_enable_llm_editor: Whether LLM editor is enabled in the action space. Default is False. Only works with function calling.
|
||||
codeact_enable_jupyter: Whether Jupyter is enabled in the action space. Default is False.
|
||||
micro_agent_name: The name of the micro agent to use for this agent.
|
||||
@@ -19,7 +19,7 @@ class AgentConfig:
|
||||
"""
|
||||
|
||||
function_calling: bool = True
|
||||
codeact_enable_browsing_delegate: bool = True
|
||||
codeact_enable_browsing: bool = True
|
||||
codeact_enable_llm_editor: bool = False
|
||||
codeact_enable_jupyter: bool = True
|
||||
micro_agent_name: str | None = None
|
||||
|
||||
@@ -49,6 +49,8 @@ class ImageContent(Content):
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
# NOTE: this is not the same as EventSource
|
||||
# These are the roles in the LLM's APIs
|
||||
role: Literal['user', 'system', 'assistant', 'tool']
|
||||
content: list[TextContent | ImageContent] = Field(default_factory=list)
|
||||
cache_enabled: bool = False
|
||||
|
||||
@@ -9,6 +9,7 @@ from openhands.llm.metrics import Metrics
|
||||
class EventSource(str, Enum):
|
||||
AGENT = 'agent'
|
||||
USER = 'user'
|
||||
ENVIRONMENT = 'environment'
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from browsergym.utils.obs import flatten_axtree_to_str
|
||||
|
||||
from openhands.core.schema import ObservationType
|
||||
from openhands.events.observation.observation import Observation
|
||||
|
||||
@@ -29,7 +31,7 @@ class BrowserOutputObservation(Observation):
|
||||
return 'Visited ' + self.url
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
ret = (
|
||||
'**BrowserOutputObservation**\n'
|
||||
f'URL: {self.url}\n'
|
||||
f'Error: {self.error}\n'
|
||||
@@ -38,5 +40,47 @@ class BrowserOutputObservation(Observation):
|
||||
f'Last browser action: {self.last_browser_action}\n'
|
||||
f'Last browser action error: {self.last_browser_action_error}\n'
|
||||
f'Focused element bid: {self.focused_element_bid}\n'
|
||||
f'CONTENT: {self.content}\n'
|
||||
f'Content: {self.content}\n'
|
||||
)
|
||||
ret += '--- Agent Observation ---\n'
|
||||
ret += self.get_agent_obs_text()
|
||||
return ret
|
||||
|
||||
def get_agent_obs_text(self) -> str:
|
||||
"""Get a concise text that will be shown to the agent."""
|
||||
text = f'[Current URL: {self.url}]\n'
|
||||
text += f'[Focused element bid: {self.focused_element_bid}]\n\n'
|
||||
if self.error:
|
||||
text += (
|
||||
'================ BEGIN error message ===============\n'
|
||||
'The following error occurred when executing the last action:\n'
|
||||
f'{self.last_browser_action_error}\n'
|
||||
'================ END error message ===============\n'
|
||||
)
|
||||
else:
|
||||
text += '[Action executed successfully.]\n'
|
||||
|
||||
try:
|
||||
# We do not filter visible only here because we want to show the full content
|
||||
# of the web page to the agent for simplicity.
|
||||
# FIXME: handle the case when the web page is too large
|
||||
cur_axtree_txt = self.get_axtree_str(filter_visible_only=False)
|
||||
text += (
|
||||
f'============== BEGIN accessibility tree ==============\n'
|
||||
f'{cur_axtree_txt}\n'
|
||||
f'============== END accessibility tree ==============\n'
|
||||
)
|
||||
except Exception as e:
|
||||
text += f'\n[Error encountered when processing the accessibility tree: {e}]'
|
||||
return text
|
||||
|
||||
def get_axtree_str(self, filter_visible_only: bool = False) -> str:
|
||||
cur_axtree_txt = flatten_axtree_to_str(
|
||||
self.axtree_object,
|
||||
extra_properties=self.extra_element_properties,
|
||||
with_clickable=True,
|
||||
skip_generic=False,
|
||||
filter_visible_only=filter_visible_only,
|
||||
)
|
||||
self._axtree_str = cur_axtree_txt
|
||||
return cur_axtree_txt
|
||||
|
||||
@@ -11,6 +11,7 @@ from openhands.events.event import Event, EventSource
|
||||
from openhands.events.serialization.event import event_from_dict, event_to_dict
|
||||
from openhands.runtime.utils.shutdown_listener import should_continue
|
||||
from openhands.storage import FileStore
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
|
||||
class EventStreamSubscriber(str, Enum):
|
||||
@@ -22,14 +23,29 @@ class EventStreamSubscriber(str, Enum):
|
||||
TEST = 'test'
|
||||
|
||||
|
||||
def session_exists(sid: str, file_store: FileStore) -> bool:
|
||||
async def session_exists(sid: str, file_store: FileStore) -> bool:
|
||||
try:
|
||||
file_store.list(f'sessions/{sid}')
|
||||
await call_sync_from_async(file_store.list, f'sessions/{sid}')
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
|
||||
|
||||
class AsyncEventStreamWrapper:
|
||||
def __init__(self, event_stream, *args, **kwargs):
|
||||
self.event_stream = event_stream
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
async def __aiter__(self):
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
# Create an async generator that yields events
|
||||
for event in self.event_stream.get_events(*self.args, **self.kwargs):
|
||||
# Run the blocking get_events() in a thread pool
|
||||
yield await loop.run_in_executor(None, lambda e=event: e) # type: ignore
|
||||
|
||||
|
||||
@dataclass
|
||||
class EventStream:
|
||||
sid: str
|
||||
@@ -71,7 +87,15 @@ class EventStream:
|
||||
end_id=None,
|
||||
reverse=False,
|
||||
filter_out_type: tuple[type[Event], ...] | None = None,
|
||||
filter_hidden=False,
|
||||
) -> Iterable[Event]:
|
||||
def should_filter(event: Event):
|
||||
if filter_hidden and hasattr(event, 'hidden') and event.hidden:
|
||||
return True
|
||||
if filter_out_type is not None and isinstance(event, filter_out_type):
|
||||
return True
|
||||
return False
|
||||
|
||||
if reverse:
|
||||
if end_id is None:
|
||||
end_id = self._cur_id - 1
|
||||
@@ -79,9 +103,7 @@ class EventStream:
|
||||
while event_id >= start_id:
|
||||
try:
|
||||
event = self.get_event(event_id)
|
||||
if filter_out_type is None or not isinstance(
|
||||
event, filter_out_type
|
||||
):
|
||||
if not should_filter(event):
|
||||
yield event
|
||||
except FileNotFoundError:
|
||||
logger.debug(f'No event found for ID {event_id}')
|
||||
@@ -93,9 +115,7 @@ class EventStream:
|
||||
break
|
||||
try:
|
||||
event = self.get_event(event_id)
|
||||
if filter_out_type is None or not isinstance(
|
||||
event, filter_out_type
|
||||
):
|
||||
if not should_filter(event):
|
||||
yield event
|
||||
except FileNotFoundError:
|
||||
break
|
||||
|
||||
+108
-80
@@ -82,6 +82,7 @@ class LLM(RetryMixin, DebugMixin):
|
||||
config: The LLM configuration.
|
||||
metrics: The metrics to use.
|
||||
"""
|
||||
self._tried_model_info = False
|
||||
self.metrics: Metrics = (
|
||||
metrics if metrics is not None else Metrics(model_name=config.model)
|
||||
)
|
||||
@@ -91,56 +92,6 @@ class LLM(RetryMixin, DebugMixin):
|
||||
# litellm actually uses base Exception here for unknown model
|
||||
self.model_info: ModelInfo | None = None
|
||||
|
||||
try:
|
||||
if self.config.model.startswith('openrouter'):
|
||||
self.model_info = litellm.get_model_info(self.config.model)
|
||||
except Exception as e:
|
||||
logger.debug(f'Error getting model info: {e}')
|
||||
|
||||
if self.config.model.startswith('litellm_proxy/'):
|
||||
# IF we are using LiteLLM proxy, get model info from LiteLLM proxy
|
||||
# GET {base_url}/v1/model/info with litellm_model_id as path param
|
||||
response = requests.get(
|
||||
f'{self.config.base_url}/v1/model/info',
|
||||
headers={'Authorization': f'Bearer {self.config.api_key}'},
|
||||
)
|
||||
resp_json = response.json()
|
||||
if 'data' not in resp_json:
|
||||
logger.error(
|
||||
f'Error getting model info from LiteLLM proxy: {resp_json}'
|
||||
)
|
||||
all_model_info = resp_json.get('data', [])
|
||||
current_model_info = next(
|
||||
(
|
||||
info
|
||||
for info in all_model_info
|
||||
if info['model_name']
|
||||
== self.config.model.removeprefix('litellm_proxy/')
|
||||
),
|
||||
None,
|
||||
)
|
||||
if current_model_info:
|
||||
self.model_info = current_model_info['model_info']
|
||||
|
||||
# Last two attempts to get model info from NAME
|
||||
if not self.model_info:
|
||||
try:
|
||||
self.model_info = litellm.get_model_info(
|
||||
self.config.model.split(':')[0]
|
||||
)
|
||||
# noinspection PyBroadException
|
||||
except Exception:
|
||||
pass
|
||||
if not self.model_info:
|
||||
try:
|
||||
self.model_info = litellm.get_model_info(
|
||||
self.config.model.split('/')[-1]
|
||||
)
|
||||
# noinspection PyBroadException
|
||||
except Exception:
|
||||
pass
|
||||
logger.debug(f'Model info: {self.model_info}')
|
||||
|
||||
if self.config.log_completions:
|
||||
if self.config.log_completions_folder is None:
|
||||
raise RuntimeError(
|
||||
@@ -148,32 +99,26 @@ class LLM(RetryMixin, DebugMixin):
|
||||
)
|
||||
os.makedirs(self.config.log_completions_folder, exist_ok=True)
|
||||
|
||||
# Set the max tokens in an LM-specific way if not set
|
||||
if self.config.max_input_tokens is None:
|
||||
if (
|
||||
self.model_info is not None
|
||||
and 'max_input_tokens' in self.model_info
|
||||
and isinstance(self.model_info['max_input_tokens'], int)
|
||||
):
|
||||
self.config.max_input_tokens = self.model_info['max_input_tokens']
|
||||
else:
|
||||
# Safe fallback for any potentially viable model
|
||||
self.config.max_input_tokens = 4096
|
||||
self._completion = partial(
|
||||
litellm_completion,
|
||||
model=self.config.model,
|
||||
api_key=self.config.api_key,
|
||||
base_url=self.config.base_url,
|
||||
api_version=self.config.api_version,
|
||||
custom_llm_provider=self.config.custom_llm_provider,
|
||||
max_tokens=self.config.max_output_tokens,
|
||||
timeout=self.config.timeout,
|
||||
temperature=self.config.temperature,
|
||||
top_p=self.config.top_p,
|
||||
drop_params=self.config.drop_params,
|
||||
)
|
||||
|
||||
if self.config.max_output_tokens is None:
|
||||
# Safe default for any potentially viable model
|
||||
self.config.max_output_tokens = 4096
|
||||
if self.model_info is not None:
|
||||
# max_output_tokens has precedence over max_tokens, if either exists.
|
||||
# litellm has models with both, one or none of these 2 parameters!
|
||||
if 'max_output_tokens' in self.model_info and isinstance(
|
||||
self.model_info['max_output_tokens'], int
|
||||
):
|
||||
self.config.max_output_tokens = self.model_info['max_output_tokens']
|
||||
elif 'max_tokens' in self.model_info and isinstance(
|
||||
self.model_info['max_tokens'], int
|
||||
):
|
||||
self.config.max_output_tokens = self.model_info['max_tokens']
|
||||
if self.vision_is_active():
|
||||
logger.debug('LLM: model has vision enabled')
|
||||
if self.is_caching_prompt_active():
|
||||
logger.debug('LLM: caching prompt enabled')
|
||||
if self.is_function_calling_active():
|
||||
logger.debug('LLM: model supports function calling')
|
||||
|
||||
self._completion = partial(
|
||||
litellm_completion,
|
||||
@@ -207,6 +152,7 @@ class LLM(RetryMixin, DebugMixin):
|
||||
)
|
||||
def wrapper(*args, **kwargs):
|
||||
"""Wrapper for the litellm completion function. Logs the input and output of the completion function."""
|
||||
self.init_model_info()
|
||||
messages: list[dict[str, Any]] | dict[str, Any] = []
|
||||
|
||||
# some callers might send the model and messages directly
|
||||
@@ -300,6 +246,87 @@ class LLM(RetryMixin, DebugMixin):
|
||||
"""
|
||||
return self._completion
|
||||
|
||||
def init_model_info(self):
|
||||
if self._tried_model_info:
|
||||
return
|
||||
self._tried_model_info = True
|
||||
try:
|
||||
if self.config.model.startswith('openrouter'):
|
||||
self.model_info = litellm.get_model_info(self.config.model)
|
||||
except Exception as e:
|
||||
logger.debug(f'Error getting model info: {e}')
|
||||
|
||||
if self.config.model.startswith('litellm_proxy/'):
|
||||
# IF we are using LiteLLM proxy, get model info from LiteLLM proxy
|
||||
# GET {base_url}/v1/model/info with litellm_model_id as path param
|
||||
response = requests.get(
|
||||
f'{self.config.base_url}/v1/model/info',
|
||||
headers={'Authorization': f'Bearer {self.config.api_key}'},
|
||||
)
|
||||
resp_json = response.json()
|
||||
if 'data' not in resp_json:
|
||||
logger.error(
|
||||
f'Error getting model info from LiteLLM proxy: {resp_json}'
|
||||
)
|
||||
all_model_info = resp_json.get('data', [])
|
||||
current_model_info = next(
|
||||
(
|
||||
info
|
||||
for info in all_model_info
|
||||
if info['model_name']
|
||||
== self.config.model.removeprefix('litellm_proxy/')
|
||||
),
|
||||
None,
|
||||
)
|
||||
if current_model_info:
|
||||
self.model_info = current_model_info['model_info']
|
||||
|
||||
# Last two attempts to get model info from NAME
|
||||
if not self.model_info:
|
||||
try:
|
||||
self.model_info = litellm.get_model_info(
|
||||
self.config.model.split(':')[0]
|
||||
)
|
||||
# noinspection PyBroadException
|
||||
except Exception:
|
||||
pass
|
||||
if not self.model_info:
|
||||
try:
|
||||
self.model_info = litellm.get_model_info(
|
||||
self.config.model.split('/')[-1]
|
||||
)
|
||||
# noinspection PyBroadException
|
||||
except Exception:
|
||||
pass
|
||||
logger.debug(f'Model info: {self.model_info}')
|
||||
|
||||
# Set the max tokens in an LM-specific way if not set
|
||||
if self.config.max_input_tokens is None:
|
||||
if (
|
||||
self.model_info is not None
|
||||
and 'max_input_tokens' in self.model_info
|
||||
and isinstance(self.model_info['max_input_tokens'], int)
|
||||
):
|
||||
self.config.max_input_tokens = self.model_info['max_input_tokens']
|
||||
else:
|
||||
# Safe fallback for any potentially viable model
|
||||
self.config.max_input_tokens = 4096
|
||||
|
||||
if self.config.max_output_tokens is None:
|
||||
# Safe default for any potentially viable model
|
||||
self.config.max_output_tokens = 4096
|
||||
if self.model_info is not None:
|
||||
# max_output_tokens has precedence over max_tokens, if either exists.
|
||||
# litellm has models with both, one or none of these 2 parameters!
|
||||
if 'max_output_tokens' in self.model_info and isinstance(
|
||||
self.model_info['max_output_tokens'], int
|
||||
):
|
||||
self.config.max_output_tokens = self.model_info['max_output_tokens']
|
||||
elif 'max_tokens' in self.model_info and isinstance(
|
||||
self.model_info['max_tokens'], int
|
||||
):
|
||||
self.config.max_output_tokens = self.model_info['max_tokens']
|
||||
|
||||
def vision_is_active(self):
|
||||
return not self.config.disable_vision and self._supports_vision()
|
||||
|
||||
@@ -324,14 +351,15 @@ class LLM(RetryMixin, DebugMixin):
|
||||
Returns:
|
||||
boolean: True if prompt caching is supported and enabled for the given model.
|
||||
"""
|
||||
return (
|
||||
self.config.caching_prompt is True
|
||||
and self.model_info is not None
|
||||
and self.model_info.get('supports_prompt_caching', False)
|
||||
and (
|
||||
return self.config.caching_prompt is True and (
|
||||
(
|
||||
self.config.model in CACHE_PROMPT_SUPPORTED_MODELS
|
||||
or self.config.model.split('/')[-1] in CACHE_PROMPT_SUPPORTED_MODELS
|
||||
)
|
||||
or (
|
||||
self.model_info is not None
|
||||
and self.model_info.get('supports_prompt_caching', False)
|
||||
)
|
||||
)
|
||||
|
||||
def is_function_calling_active(self) -> bool:
|
||||
|
||||
@@ -12,6 +12,7 @@ from openhands.events.event import Event, EventSource
|
||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
from openhands.events.observation.delegate import AgentDelegateObservation
|
||||
from openhands.events.observation.empty import NullObservation
|
||||
from openhands.events.observation.error import FatalErrorObservation
|
||||
from openhands.events.observation.observation import Observation
|
||||
from openhands.events.serialization.event import event_to_dict
|
||||
from openhands.events.stream import EventStream
|
||||
@@ -33,6 +34,7 @@ class ShortTermHistory(list[Event]):
|
||||
NullObservation,
|
||||
ChangeAgentStateAction,
|
||||
AgentStateChangedObservation,
|
||||
FatalErrorObservation,
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
|
||||
@@ -136,6 +136,8 @@ class Runtime(FileEditRuntimeMixin):
|
||||
)
|
||||
observation._cause = event.id # type: ignore[attr-defined]
|
||||
observation.tool_call_metadata = event.tool_call_metadata
|
||||
|
||||
# this might be unnecessary, since source should be set by the event stream when we're here
|
||||
source = event.source if event.source else EventSource.AGENT
|
||||
await self.event_stream.async_add_event(observation, source) # type: ignore[arg-type]
|
||||
|
||||
|
||||
@@ -81,7 +81,10 @@ class BrowserEnv:
|
||||
raise ValueError(
|
||||
f'Unsupported browsergym eval env: {self.browsergym_eval_env}'
|
||||
)
|
||||
env = gym.make(self.browsergym_eval_env)
|
||||
env = gym.make(
|
||||
self.browsergym_eval_env,
|
||||
tags_to_mark='all',
|
||||
)
|
||||
else:
|
||||
env = gym.make(
|
||||
'browsergym/openended',
|
||||
@@ -89,6 +92,7 @@ class BrowserEnv:
|
||||
wait_for_user_message=False,
|
||||
headless=True,
|
||||
disable_env_checker=True,
|
||||
tags_to_mark='all',
|
||||
)
|
||||
|
||||
obs, info = env.reset()
|
||||
|
||||
@@ -17,7 +17,7 @@ class DockerRuntimeBuilder(RuntimeBuilder):
|
||||
|
||||
version_info = self.docker_client.version()
|
||||
server_version = version_info.get('Version', '').replace('-', '.')
|
||||
if tuple(map(int, server_version.split('.'))) < (18, 9):
|
||||
if tuple(map(int, server_version.split('.')[:2])) < (18, 9):
|
||||
raise RuntimeError('Docker server version must be >= 18.09 to use BuildKit')
|
||||
|
||||
self.rolling_logger = RollingLogger(max_lines=10)
|
||||
|
||||
@@ -4,9 +4,7 @@ import tarfile
|
||||
from glob import glob
|
||||
|
||||
from e2b import Sandbox as E2BSandbox
|
||||
from e2b.sandbox.exception import (
|
||||
TimeoutException,
|
||||
)
|
||||
from e2b.sandbox.exception import TimeoutException
|
||||
|
||||
from openhands.core.config import SandboxConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
@@ -240,7 +240,7 @@ class EventStreamRuntime(Runtime):
|
||||
|
||||
@tenacity.retry(
|
||||
stop=tenacity.stop_after_attempt(5) | stop_if_should_exit(),
|
||||
wait=tenacity.wait_exponential(multiplier=1, min=4, max=60),
|
||||
wait=tenacity.wait_fixed(5),
|
||||
)
|
||||
def _init_container(self):
|
||||
try:
|
||||
|
||||
@@ -46,13 +46,13 @@ class EditTool:
|
||||
if command == 'view':
|
||||
return self.view(_path, view_range)
|
||||
elif command == 'create':
|
||||
if not file_text:
|
||||
if file_text is None:
|
||||
raise ToolError('Parameter `file_text` is required for command: create')
|
||||
self.write_file(_path, file_text)
|
||||
self._file_history[_path].append(file_text)
|
||||
return ToolResult(output=f'File created successfully at: {_path}')
|
||||
elif command == 'str_replace':
|
||||
if not old_str:
|
||||
if old_str is None:
|
||||
raise ToolError(
|
||||
'Parameter `old_str` is required for command: str_replace'
|
||||
)
|
||||
@@ -62,7 +62,7 @@ class EditTool:
|
||||
raise ToolError(
|
||||
'Parameter `insert_line` is required for command: insert'
|
||||
)
|
||||
if not new_str:
|
||||
if new_str is None:
|
||||
raise ToolError('Parameter `new_str` is required for command: insert')
|
||||
return self.insert(_path, insert_line, new_str)
|
||||
elif command == 'undo_edit':
|
||||
|
||||
@@ -87,6 +87,7 @@ COPY ./code/pyproject.toml ./code/poetry.lock /openhands/code/
|
||||
RUN if [ -d /openhands/code/openhands ]; then rm -rf /openhands/code/openhands; fi
|
||||
COPY ./code/pyproject.toml ./code/poetry.lock /openhands/code/
|
||||
COPY ./code/openhands /openhands/code/openhands
|
||||
RUN chmod a+rwx /openhands/code/openhands/__init__.py
|
||||
|
||||
# ================================================================
|
||||
# END: Build from versioned image
|
||||
|
||||
@@ -147,6 +147,7 @@ class InvariantAnalyzer(SecurityAnalyzer):
|
||||
new_event = action_from_dict(
|
||||
{'action': 'change_agent_state', 'args': {'agent_state': 'user_confirmed'}}
|
||||
)
|
||||
# we should confirm only on agent actions
|
||||
event_source = event.source if event.source else EventSource.AGENT
|
||||
await call_sync_from_async(self.event_stream.add_event, new_event, event_source)
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
@@ -10,10 +10,12 @@ from openhands.core.logger import openhands_logger as logger
|
||||
class FeedbackDataModel(BaseModel):
|
||||
version: str
|
||||
email: str
|
||||
token: str
|
||||
feedback: Literal['positive', 'negative']
|
||||
polarity: Literal['positive', 'negative']
|
||||
feedback: Literal[
|
||||
'positive', 'negative'
|
||||
] # TODO: remove this, its here for backward compatibility
|
||||
permissions: Literal['public', 'private']
|
||||
trajectory: list[dict[str, Any]]
|
||||
trajectory: Optional[list[dict[str, Any]]]
|
||||
|
||||
|
||||
FEEDBACK_URL = 'https://share-od-trajectory-3u9bw9tx.uc.gateway.dev/share_od_trajectory'
|
||||
@@ -21,6 +23,7 @@ FEEDBACK_URL = 'https://share-od-trajectory-3u9bw9tx.uc.gateway.dev/share_od_tra
|
||||
|
||||
def store_feedback(feedback: FeedbackDataModel) -> dict[str, str]:
|
||||
# Start logging
|
||||
feedback.feedback = feedback.polarity
|
||||
display_feedback = feedback.model_dump()
|
||||
if 'trajectory' in display_feedback:
|
||||
display_feedback['trajectory'] = (
|
||||
|
||||
@@ -0,0 +1,128 @@
|
||||
import os
|
||||
|
||||
import httpx
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.sheets_client import GoogleSheetsClient
|
||||
|
||||
GITHUB_CLIENT_ID = os.getenv('GITHUB_CLIENT_ID', '').strip()
|
||||
GITHUB_CLIENT_SECRET = os.getenv('GITHUB_CLIENT_SECRET', '').strip()
|
||||
|
||||
|
||||
class UserVerifier:
|
||||
def __init__(self) -> None:
|
||||
logger.info('Initializing UserVerifier')
|
||||
self.file_users: list[str] | None = None
|
||||
self.sheets_client: GoogleSheetsClient | None = None
|
||||
self.spreadsheet_id: str | None = None
|
||||
|
||||
# Initialize from environment variables
|
||||
self._init_file_users()
|
||||
self._init_sheets_client()
|
||||
|
||||
def _init_file_users(self) -> None:
|
||||
"""Load users from text file if configured"""
|
||||
waitlist = os.getenv('GITHUB_USER_LIST_FILE')
|
||||
if not waitlist:
|
||||
logger.info('GITHUB_USER_LIST_FILE not configured')
|
||||
return
|
||||
|
||||
if not os.path.exists(waitlist):
|
||||
logger.error(f'User list file not found: {waitlist}')
|
||||
raise FileNotFoundError(f'User list file not found: {waitlist}')
|
||||
|
||||
try:
|
||||
with open(waitlist, 'r') as f:
|
||||
self.file_users = [line.strip() for line in f if line.strip()]
|
||||
logger.info(
|
||||
f'Successfully loaded {len(self.file_users)} users from {waitlist}'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'Error reading user list file {waitlist}: {str(e)}')
|
||||
|
||||
def _init_sheets_client(self) -> None:
|
||||
"""Initialize Google Sheets client if configured"""
|
||||
sheet_id = os.getenv('GITHUB_USERS_SHEET_ID')
|
||||
|
||||
if not sheet_id:
|
||||
logger.info('GITHUB_USERS_SHEET_ID not configured')
|
||||
return
|
||||
|
||||
logger.info('Initializing Google Sheets integration')
|
||||
self.sheets_client = GoogleSheetsClient()
|
||||
self.spreadsheet_id = sheet_id
|
||||
|
||||
def is_active(self) -> bool:
|
||||
return bool(self.file_users or (self.sheets_client and self.spreadsheet_id))
|
||||
|
||||
def is_user_allowed(self, username: str) -> bool:
|
||||
"""Check if user is allowed based on file and/or sheet configuration"""
|
||||
if not self.is_active():
|
||||
return True
|
||||
|
||||
logger.info(f'Checking if GitHub user {username} is allowed')
|
||||
if self.file_users:
|
||||
if username in self.file_users:
|
||||
logger.info(f'User {username} found in text file allowlist')
|
||||
return True
|
||||
logger.debug(f'User {username} not found in text file allowlist')
|
||||
|
||||
if self.sheets_client and self.spreadsheet_id:
|
||||
sheet_users = self.sheets_client.get_usernames(self.spreadsheet_id)
|
||||
if username in sheet_users:
|
||||
logger.info(f'User {username} found in Google Sheets allowlist')
|
||||
return True
|
||||
logger.debug(f'User {username} not found in Google Sheets allowlist')
|
||||
|
||||
logger.info(f'User {username} not found in any allowlist')
|
||||
return False
|
||||
|
||||
|
||||
async def authenticate_github_user(auth_token) -> bool:
|
||||
user_verifier = UserVerifier()
|
||||
|
||||
if not user_verifier.is_active():
|
||||
logger.info('No user verification sources configured - allowing all users')
|
||||
return True
|
||||
|
||||
logger.info('Checking GitHub token')
|
||||
|
||||
if not auth_token:
|
||||
logger.warning('No GitHub token provided')
|
||||
return False
|
||||
|
||||
login = await get_github_user(auth_token)
|
||||
|
||||
if not user_verifier.is_user_allowed(login):
|
||||
logger.warning(f'GitHub user {login} not in allow list')
|
||||
return False
|
||||
|
||||
logger.info(f'GitHub user {login} authenticated')
|
||||
return True
|
||||
|
||||
|
||||
async def get_github_user(token: str) -> str:
|
||||
"""Get GitHub user info from token.
|
||||
|
||||
Args:
|
||||
token: GitHub access token
|
||||
|
||||
Returns:
|
||||
Tuple of (login, error_message)
|
||||
If successful, error_message is None
|
||||
If failed, login is None and error_message contains the error
|
||||
"""
|
||||
logger.info('Fetching GitHub user info from token')
|
||||
headers = {
|
||||
'Accept': 'application/vnd.github+json',
|
||||
'Authorization': f'Bearer {token}',
|
||||
'X-GitHub-Api-Version': '2022-11-28',
|
||||
}
|
||||
async with httpx.AsyncClient() as client:
|
||||
logger.debug('Making request to GitHub API')
|
||||
response = await client.get('https://api.github.com/user', headers=headers)
|
||||
response.raise_for_status()
|
||||
user_data = response.json()
|
||||
login = user_data.get('login')
|
||||
logger.info(f'Successfully retrieved GitHub user: {login}')
|
||||
return login
|
||||
+74
-52
@@ -13,6 +13,11 @@ from pathspec.patterns import GitWildMatchPattern
|
||||
|
||||
from openhands.security.options import SecurityAnalyzers
|
||||
from openhands.server.data_models.feedback import FeedbackDataModel, store_feedback
|
||||
from openhands.server.github import (
|
||||
GITHUB_CLIENT_ID,
|
||||
GITHUB_CLIENT_SECRET,
|
||||
authenticate_github_user,
|
||||
)
|
||||
from openhands.storage import get_file_store
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
@@ -52,6 +57,7 @@ from openhands.events.observation import (
|
||||
NullObservation,
|
||||
)
|
||||
from openhands.events.serialization import event_to_dict
|
||||
from openhands.events.stream import AsyncEventStreamWrapper
|
||||
from openhands.llm import bedrock
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.server.auth import get_sid_from_token, sign_token
|
||||
@@ -64,24 +70,6 @@ config = load_app_config()
|
||||
file_store = get_file_store(config.file_store, config.file_store_path)
|
||||
session_manager = SessionManager(config, file_store)
|
||||
|
||||
GITHUB_CLIENT_ID = os.getenv('GITHUB_CLIENT_ID', '').strip()
|
||||
GITHUB_CLIENT_SECRET = os.getenv('GITHUB_CLIENT_SECRET', '').strip()
|
||||
|
||||
# New global variable to store the user list
|
||||
GITHUB_USER_LIST = None
|
||||
|
||||
|
||||
# New function to load the user list
|
||||
def load_github_user_list():
|
||||
global GITHUB_USER_LIST
|
||||
waitlist = os.getenv('GITHUB_USER_LIST_FILE')
|
||||
if waitlist:
|
||||
with open(waitlist, 'r') as f:
|
||||
GITHUB_USER_LIST = [line.strip() for line in f if line.strip()]
|
||||
|
||||
|
||||
load_github_user_list()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
@@ -216,7 +204,13 @@ async def attach_session(request: Request, call_next):
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
# For all other methods, validate the Authorization header
|
||||
github_token = request.headers.get('X-GitHub-Token')
|
||||
if not await authenticate_github_user(github_token):
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={'error': 'Not authenticated'},
|
||||
)
|
||||
|
||||
if not request.headers.get('Authorization'):
|
||||
logger.warning('Missing Authorization header')
|
||||
return JSONResponse(
|
||||
@@ -308,11 +302,28 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
{"action": "finish", "args": {}}
|
||||
```
|
||||
"""
|
||||
await asyncio.wait_for(websocket.accept(), 10)
|
||||
# Get protocols from Sec-WebSocket-Protocol header
|
||||
protocols = websocket.headers.get('sec-websocket-protocol', '').split(', ')
|
||||
|
||||
if websocket.query_params.get('token'):
|
||||
token = websocket.query_params.get('token')
|
||||
sid = get_sid_from_token(token, config.jwt_secret)
|
||||
# The first protocol should be our real protocol (e.g. 'openhands')
|
||||
# The second protocol should contain our auth token
|
||||
if len(protocols) < 3:
|
||||
logger.error('Expected 3 websocket protocols, got %d', len(protocols))
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
return
|
||||
|
||||
real_protocol = protocols[0]
|
||||
jwt_token = protocols[1] if protocols[1] != 'NO_JWT' else ''
|
||||
github_token = protocols[2] if protocols[2] != 'NO_GITHUB' else ''
|
||||
|
||||
if not await authenticate_github_user(github_token):
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
return
|
||||
|
||||
await asyncio.wait_for(websocket.accept(subprotocol=real_protocol), 10)
|
||||
|
||||
if jwt_token:
|
||||
sid = get_sid_from_token(jwt_token, config.jwt_secret)
|
||||
|
||||
if sid == '':
|
||||
await websocket.send_json({'error': 'Invalid token', 'error_code': 401})
|
||||
@@ -320,18 +331,21 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
return
|
||||
else:
|
||||
sid = str(uuid.uuid4())
|
||||
token = sign_token({'sid': sid}, config.jwt_secret)
|
||||
jwt_token = sign_token({'sid': sid}, config.jwt_secret)
|
||||
|
||||
logger.info(f'New session: {sid}')
|
||||
session = session_manager.add_or_restart_session(sid, websocket)
|
||||
await websocket.send_json({'token': token, 'status': 'ok'})
|
||||
await websocket.send_json({'token': jwt_token, 'status': 'ok'})
|
||||
|
||||
latest_event_id = -1
|
||||
if websocket.query_params.get('latest_event_id'):
|
||||
latest_event_id = int(websocket.query_params.get('latest_event_id'))
|
||||
for event in session.agent_session.event_stream.get_events(
|
||||
start_id=latest_event_id + 1
|
||||
):
|
||||
|
||||
async_stream = AsyncEventStreamWrapper(
|
||||
session.agent_session.event_stream, latest_event_id + 1
|
||||
)
|
||||
|
||||
async for event in async_stream:
|
||||
if isinstance(
|
||||
event,
|
||||
(
|
||||
@@ -469,19 +483,17 @@ async def list_files(request: Request, path: str | None = None):
|
||||
)
|
||||
|
||||
runtime: Runtime = request.state.conversation.runtime
|
||||
file_list = await asyncio.create_task(
|
||||
call_sync_from_async(runtime.list_files, path)
|
||||
)
|
||||
file_list = await call_sync_from_async(runtime.list_files, path)
|
||||
if path:
|
||||
file_list = [os.path.join(path, f) for f in file_list]
|
||||
|
||||
file_list = [f for f in file_list if f not in FILES_TO_IGNORE]
|
||||
|
||||
def filter_for_gitignore(file_list, base_path):
|
||||
async def filter_for_gitignore(file_list, base_path):
|
||||
gitignore_path = os.path.join(base_path, '.gitignore')
|
||||
try:
|
||||
read_action = FileReadAction(gitignore_path)
|
||||
observation = runtime.run_action(read_action)
|
||||
observation = await call_sync_from_async(runtime.run_action, read_action)
|
||||
spec = PathSpec.from_lines(
|
||||
GitWildMatchPattern, observation.content.splitlines()
|
||||
)
|
||||
@@ -491,7 +503,7 @@ async def list_files(request: Request, path: str | None = None):
|
||||
file_list = [entry for entry in file_list if not spec.match_file(entry)]
|
||||
return file_list
|
||||
|
||||
file_list = filter_for_gitignore(file_list, '')
|
||||
file_list = await filter_for_gitignore(file_list, '')
|
||||
|
||||
return file_list
|
||||
|
||||
@@ -634,14 +646,14 @@ async def upload_file(request: Request, files: list[UploadFile]):
|
||||
|
||||
|
||||
@app.post('/api/submit-feedback')
|
||||
async def submit_feedback(request: Request, feedback: FeedbackDataModel):
|
||||
async def submit_feedback(request: Request):
|
||||
"""Submit user feedback.
|
||||
|
||||
This function stores the provided feedback data.
|
||||
|
||||
To submit feedback:
|
||||
```sh
|
||||
curl -X POST -F "email=test@example.com" -F "token=abc" -F "feedback=positive" -F "permissions=private" -F "trajectory={}" http://localhost:3000/api/submit-feedback
|
||||
curl -X POST -d '{"email": "test@example.com"}' -H "Authorization:"
|
||||
```
|
||||
|
||||
Args:
|
||||
@@ -656,8 +668,23 @@ async def submit_feedback(request: Request, feedback: FeedbackDataModel):
|
||||
"""
|
||||
# Assuming the storage service is already configured in the backend
|
||||
# and there is a function to handle the storage.
|
||||
body = await request.json()
|
||||
async_stream = AsyncEventStreamWrapper(
|
||||
request.state.conversation.event_stream, filter_hidden=True
|
||||
)
|
||||
trajectory = []
|
||||
async for event in async_stream:
|
||||
trajectory.append(event_to_dict(event))
|
||||
feedback = FeedbackDataModel(
|
||||
email=body.get('email', ''),
|
||||
version=body.get('version', ''),
|
||||
permissions=body.get('permissions', 'private'),
|
||||
polarity=body.get('polarity', ''),
|
||||
feedback=body.get('polarity', ''),
|
||||
trajectory=trajectory,
|
||||
)
|
||||
try:
|
||||
feedback_data = store_feedback(feedback)
|
||||
feedback_data = await call_sync_from_async(store_feedback, feedback)
|
||||
return JSONResponse(status_code=200, content=feedback_data)
|
||||
except Exception as e:
|
||||
logger.error(f'Error submitting feedback: {e}')
|
||||
@@ -827,26 +854,21 @@ def github_callback(auth_code: AuthCode):
|
||||
)
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
login: str # GitHub login handle
|
||||
|
||||
|
||||
@app.post('/api/authenticate')
|
||||
def authenticate(user: User | None = None):
|
||||
global GITHUB_USER_LIST
|
||||
async def authenticate(request: Request):
|
||||
token = request.headers.get('X-GitHub-Token')
|
||||
if not await authenticate_github_user(token):
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={'error': 'Not authorized via GitHub waitlist'},
|
||||
)
|
||||
|
||||
# Only check if waitlist is provided
|
||||
if GITHUB_USER_LIST:
|
||||
if user is None or user.login not in GITHUB_USER_LIST:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
content={'error': 'User not on waitlist'},
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
response = JSONResponse(
|
||||
status_code=status.HTTP_200_OK, content={'message': 'User authenticated'}
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class SPAStaticFiles(StaticFiles):
|
||||
async def get_response(self, path: str, scope):
|
||||
|
||||
@@ -14,7 +14,7 @@ class LocalhostCORSMiddleware(CORSMiddleware):
|
||||
def __init__(self, app: ASGIApp, **kwargs) -> None:
|
||||
super().__init__(app, **kwargs)
|
||||
|
||||
async def is_allowed_origin(self, origin: str) -> bool:
|
||||
def is_allowed_origin(self, origin: str) -> bool:
|
||||
if origin:
|
||||
parsed = urlparse(origin)
|
||||
hostname = parsed.hostname or ''
|
||||
@@ -24,7 +24,7 @@ class LocalhostCORSMiddleware(CORSMiddleware):
|
||||
return True
|
||||
|
||||
# For missing origin or other origins, use the parent class's logic
|
||||
return await super().is_allowed_origin(origin)
|
||||
return super().is_allowed_origin(origin)
|
||||
|
||||
|
||||
class NoCacheMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
@@ -101,7 +101,6 @@ class AgentSession:
|
||||
agent_configs: dict[str, AgentConfig] | None = None,
|
||||
status_message_callback: Optional[Callable] = None,
|
||||
):
|
||||
self.loop = asyncio.get_running_loop()
|
||||
self._create_security_analyzer(config.security.security_analyzer)
|
||||
await self._create_runtime(
|
||||
runtime_name=runtime_name,
|
||||
@@ -118,15 +117,20 @@ class AgentSession:
|
||||
agent_configs=agent_configs,
|
||||
)
|
||||
self.event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.INIT), EventSource.USER
|
||||
ChangeAgentStateAction(AgentState.INIT), EventSource.ENVIRONMENT
|
||||
)
|
||||
if self.controller:
|
||||
self.controller.agent_task = self.controller.start_step_loop()
|
||||
await self.controller.agent_task # type: ignore
|
||||
|
||||
async def close(self):
|
||||
def close(self):
|
||||
"""Closes the Agent session"""
|
||||
self._closed = True
|
||||
def inner_close():
|
||||
asyncio.run(self._close())
|
||||
asyncio.get_event_loop().run_in_executor(None, inner_close)
|
||||
|
||||
async def _close(self):
|
||||
if self._closed:
|
||||
return
|
||||
if self.controller is not None:
|
||||
@@ -138,10 +142,6 @@ class AgentSession:
|
||||
if self.security_analyzer is not None:
|
||||
await self.security_analyzer.close()
|
||||
|
||||
if self.loop:
|
||||
self.loop.stop()
|
||||
|
||||
self._closed = True
|
||||
|
||||
def _create_security_analyzer(self, security_analyzer: str | None):
|
||||
"""Creates a SecurityAnalyzer instance that will be used to analyze the agent actions
|
||||
|
||||
@@ -35,7 +35,7 @@ class SessionManager:
|
||||
|
||||
def add_or_restart_session(self, sid: str, ws_conn: WebSocket) -> Session:
|
||||
if sid in self._sessions:
|
||||
asyncio.create_task(self._sessions[sid].close())
|
||||
self._sessions[sid].close()
|
||||
self._sessions[sid] = Session(
|
||||
sid=sid, file_store=self.file_store, ws=ws_conn, config=self.config
|
||||
)
|
||||
@@ -47,7 +47,7 @@ class SessionManager:
|
||||
return self._sessions.get(sid)
|
||||
|
||||
async def attach_to_conversation(self, sid: str) -> Conversation | None:
|
||||
if not session_exists(sid, self.file_store):
|
||||
if not await session_exists(sid, self.file_store):
|
||||
return None
|
||||
c = Conversation(sid, file_store=self.file_store, config=self.config)
|
||||
await c.connect()
|
||||
@@ -87,7 +87,7 @@ class SessionManager:
|
||||
for sid in session_ids_to_remove:
|
||||
to_del_session: Session | None = self._sessions.pop(sid, None)
|
||||
if to_del_session is not None:
|
||||
await to_del_session.close()
|
||||
to_del_session.close()
|
||||
logger.debug(
|
||||
f'Session {sid} and related resource have been removed due to inactivity.'
|
||||
)
|
||||
|
||||
@@ -25,8 +25,6 @@ from openhands.runtime.utils.shutdown_listener import should_continue
|
||||
from openhands.server.session.agent_session import AgentSession
|
||||
from openhands.storage.files import FileStore
|
||||
|
||||
DEL_DELT_SEC = 60 * 60 * 5
|
||||
|
||||
|
||||
class Session:
|
||||
sid: str
|
||||
@@ -49,9 +47,9 @@ class Session:
|
||||
self.config = config
|
||||
self.loop = asyncio.get_event_loop()
|
||||
|
||||
async def close(self):
|
||||
def close(self):
|
||||
self.is_alive = False
|
||||
await self.agent_session.close()
|
||||
self.agent_session.close()
|
||||
|
||||
async def loop_recv(self):
|
||||
try:
|
||||
@@ -65,18 +63,19 @@ class Session:
|
||||
continue
|
||||
await self.dispatch(data)
|
||||
except WebSocketDisconnect:
|
||||
await self.close()
|
||||
logger.debug('WebSocket disconnected, sid: %s', self.sid)
|
||||
logger.info('WebSocket disconnected, sid: %s', self.sid)
|
||||
self.close()
|
||||
except RuntimeError as e:
|
||||
await self.close()
|
||||
logger.exception('Error in loop_recv: %s', e)
|
||||
self.close()
|
||||
|
||||
async def _initialize_agent(self, data: dict):
|
||||
self.agent_session.event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.LOADING), EventSource.USER
|
||||
ChangeAgentStateAction(AgentState.LOADING), EventSource.ENVIRONMENT
|
||||
)
|
||||
self.agent_session.event_stream.add_event(
|
||||
AgentStateChangedObservation('', AgentState.LOADING), EventSource.AGENT
|
||||
AgentStateChangedObservation('', AgentState.LOADING),
|
||||
EventSource.ENVIRONMENT,
|
||||
)
|
||||
# Extract the agent-relevant arguments from the request
|
||||
args = {key: value for key, value in data.get('args', {}).items()}
|
||||
@@ -138,12 +137,19 @@ class Session:
|
||||
return
|
||||
if event.source == EventSource.AGENT:
|
||||
await self.send(event_to_dict(event))
|
||||
elif event.source == EventSource.USER and isinstance(
|
||||
event, CmdOutputObservation
|
||||
# NOTE: ipython observations are not sent here currently
|
||||
elif event.source == EventSource.ENVIRONMENT and isinstance(
|
||||
event, (CmdOutputObservation, AgentStateChangedObservation)
|
||||
):
|
||||
await self.send(event_to_dict(event))
|
||||
# feedback from the environment to agent actions is understood as agent events by the UI
|
||||
event_dict = event_to_dict(event)
|
||||
event_dict['source'] = EventSource.AGENT
|
||||
await self.send(event_dict)
|
||||
elif isinstance(event, ErrorObservation):
|
||||
await self.send(event_to_dict(event))
|
||||
# send error events as agent events to the UI
|
||||
event_dict = event_to_dict(event)
|
||||
event_dict['source'] = EventSource.AGENT
|
||||
await self.send(event_dict)
|
||||
|
||||
async def dispatch(self, data: dict):
|
||||
action = data.get('action', '')
|
||||
@@ -165,10 +171,12 @@ class Session:
|
||||
'Model does not support image upload, change to a different model or try without an image.'
|
||||
)
|
||||
return
|
||||
if self.agent_session.loop:
|
||||
if self.loop:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._add_event(event, EventSource.USER), self.agent_session.loop
|
||||
self._add_event(event, EventSource.USER), self.loop
|
||||
) # type: ignore
|
||||
else:
|
||||
raise RuntimeError('No event loop found')
|
||||
|
||||
async def _add_event(self, event, event_source):
|
||||
self.agent_session.event_stream.add_event(event, EventSource.USER)
|
||||
@@ -192,26 +200,10 @@ class Session:
|
||||
"""Sends an error message to the client."""
|
||||
return await self.send({'error': True, 'message': message})
|
||||
|
||||
async def send_message(self, message: str) -> bool:
|
||||
"""Sends a message to the client."""
|
||||
return await self.send({'message': message})
|
||||
|
||||
async def send_status_message(self, message: str) -> bool:
|
||||
"""Sends a status message to the client."""
|
||||
return await self.send({'status': message})
|
||||
|
||||
def update_connection(self, ws: WebSocket):
|
||||
self.websocket = ws
|
||||
self.is_alive = True
|
||||
self.last_active_ts = int(time.time())
|
||||
|
||||
def load_from_data(self, data: dict) -> bool:
|
||||
self.last_active_ts = data.get('last_active_ts', 0)
|
||||
if self.last_active_ts < int(time.time()) - DEL_DELT_SEC:
|
||||
return False
|
||||
self.is_alive = data.get('is_alive', False)
|
||||
return True
|
||||
|
||||
def queue_status_message(self, message: str):
|
||||
"""Queues a status message to be sent asynchronously."""
|
||||
# Ensure the coroutine runs in the main event loop
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
from typing import List
|
||||
|
||||
from google.auth import default
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
class GoogleSheetsClient:
|
||||
def __init__(self):
|
||||
"""Initialize Google Sheets client using workload identity.
|
||||
Uses application default credentials which supports workload identity when running in GCP.
|
||||
"""
|
||||
logger.info('Initializing Google Sheets client with workload identity')
|
||||
try:
|
||||
credentials, project = default(
|
||||
scopes=['https://www.googleapis.com/auth/spreadsheets.readonly']
|
||||
)
|
||||
logger.info(f'Successfully obtained credentials for project: {project}')
|
||||
self.service = build('sheets', 'v4', credentials=credentials)
|
||||
logger.info('Successfully initialized Google Sheets API service')
|
||||
except Exception as e:
|
||||
logger.error(f'Failed to initialize Google Sheets client: {str(e)}')
|
||||
self.service = None
|
||||
|
||||
def get_usernames(self, spreadsheet_id: str, range_name: str = 'A:A') -> List[str]:
|
||||
"""Get list of usernames from specified Google Sheet.
|
||||
|
||||
Args:
|
||||
spreadsheet_id: The ID of the Google Sheet
|
||||
range_name: The A1 notation of the range to fetch
|
||||
|
||||
Returns:
|
||||
List of usernames from the sheet
|
||||
"""
|
||||
if not self.service:
|
||||
logger.error('Google Sheets service not initialized')
|
||||
return []
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
f'Fetching usernames from sheet {spreadsheet_id}, range {range_name}'
|
||||
)
|
||||
result = (
|
||||
self.service.spreadsheets()
|
||||
.values()
|
||||
.get(spreadsheetId=spreadsheet_id, range=range_name)
|
||||
.execute()
|
||||
)
|
||||
|
||||
values = result.get('values', [])
|
||||
usernames = [
|
||||
str(cell[0]).strip() for cell in values if cell and cell[0].strip()
|
||||
]
|
||||
logger.info(
|
||||
f'Successfully fetched {len(usernames)} usernames from Google Sheet'
|
||||
)
|
||||
return usernames
|
||||
|
||||
except HttpError as err:
|
||||
logger.error(f'Error accessing Google Sheet {spreadsheet_id}: {err}')
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'Unexpected error accessing Google Sheet {spreadsheet_id}: {str(e)}'
|
||||
)
|
||||
return []
|
||||
Generated
+20
-2
@@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "aenum"
|
||||
@@ -2319,6 +2319,24 @@ files = [
|
||||
google-auth = "*"
|
||||
httplib2 = ">=0.19.0"
|
||||
|
||||
[[package]]
|
||||
name = "google-auth-oauthlib"
|
||||
version = "1.2.1"
|
||||
description = "Google Authentication Library"
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
files = [
|
||||
{file = "google_auth_oauthlib-1.2.1-py2.py3-none-any.whl", hash = "sha256:2d58a27262d55aa1b87678c3ba7142a080098cbc2024f903c62355deb235d91f"},
|
||||
{file = "google_auth_oauthlib-1.2.1.tar.gz", hash = "sha256:afd0cad092a2eaa53cd8e8298557d6de1034c6cb4a740500b5357b648af97263"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
google-auth = ">=2.15.0"
|
||||
requests-oauthlib = ">=0.7.0"
|
||||
|
||||
[package.extras]
|
||||
tool = ["click (>=6.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "google-cloud-aiplatform"
|
||||
version = "1.70.0"
|
||||
@@ -10109,4 +10127,4 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.12"
|
||||
content-hash = "2b268ef696ace0d8170276407dbdeb414134477839ebe4b7ecf29b1a1fe2cef3"
|
||||
content-hash = "2a4f90bb5c7f7d82160f57d71af7e81c7acef69426d0e1e46e1da09972a6215f"
|
||||
|
||||
+4
-3
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "openhands-ai"
|
||||
version = "0.11.0"
|
||||
version = "0.12.0"
|
||||
description = "OpenHands: Code Less, Make More"
|
||||
authors = ["OpenHands"]
|
||||
license = "MIT"
|
||||
@@ -16,6 +16,9 @@ datasets = "*"
|
||||
pandas = "*"
|
||||
litellm = "^1.51.1"
|
||||
google-generativeai = "*" # To use litellm with Gemini Pro API
|
||||
google-api-python-client = "*" # For Google Sheets API
|
||||
google-auth-httplib2 = "*" # For Google Sheets authentication
|
||||
google-auth-oauthlib = "*" # For Google Sheets OAuth
|
||||
termcolor = "*"
|
||||
seaborn = "*"
|
||||
docker = "*"
|
||||
@@ -89,7 +92,6 @@ reportlab = "*"
|
||||
[tool.coverage.run]
|
||||
concurrency = ["gevent"]
|
||||
|
||||
|
||||
[tool.poetry.group.runtime.dependencies]
|
||||
jupyterlab = "*"
|
||||
notebook = "*"
|
||||
@@ -120,7 +122,6 @@ ignore = ["D1"]
|
||||
[tool.ruff.lint.pydocstyle]
|
||||
convention = "google"
|
||||
|
||||
|
||||
[tool.poetry.group.evaluation.dependencies]
|
||||
streamlit = "*"
|
||||
whatthepatch = "*"
|
||||
|
||||
@@ -232,3 +232,85 @@ def test_ipython_package_install(temp_dir, runtime_cls, run_as_openhands):
|
||||
)
|
||||
|
||||
_close_test_runtime(runtime)
|
||||
|
||||
|
||||
def test_ipython_file_editor_permissions_as_openhands(temp_dir, runtime_cls):
|
||||
"""Test file editor permission behavior when running as different users."""
|
||||
runtime = _load_runtime(temp_dir, runtime_cls, run_as_openhands=True)
|
||||
sandbox_dir = _get_sandbox_folder(runtime)
|
||||
|
||||
# Create a file owned by root with restricted permissions
|
||||
action = CmdRunAction(
|
||||
command='sudo touch /root/test.txt && sudo chmod 600 /root/test.txt'
|
||||
)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
|
||||
# Try to view the file as openhands user - should fail with permission denied
|
||||
test_code = "print(file_editor(command='view', path='/root/test.txt'))"
|
||||
action = IPythonRunCellAction(code=test_code)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert 'Permission denied' in obs.content
|
||||
|
||||
# Try to edit the file as openhands user - should fail with permission denied
|
||||
test_code = "print(file_editor(command='str_replace', path='/root/test.txt', old_str='', new_str='test'))"
|
||||
action = IPythonRunCellAction(code=test_code)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert 'Permission denied' in obs.content
|
||||
|
||||
# Try to create a file in root directory - should fail with permission denied
|
||||
test_code = (
|
||||
"print(file_editor(command='create', path='/root/new.txt', file_text='test'))"
|
||||
)
|
||||
action = IPythonRunCellAction(code=test_code)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert 'Permission denied' in obs.content
|
||||
|
||||
# Try to use file editor in openhands sandbox directory - should work
|
||||
test_code = f"""
|
||||
# Create file
|
||||
print(file_editor(command='create', path='{sandbox_dir}/test.txt', file_text='Line 1\\nLine 2\\nLine 3'))
|
||||
|
||||
# View file
|
||||
print(file_editor(command='view', path='{sandbox_dir}/test.txt'))
|
||||
|
||||
# Edit file
|
||||
print(file_editor(command='str_replace', path='{sandbox_dir}/test.txt', old_str='Line 2', new_str='New Line 2'))
|
||||
|
||||
# Undo edit
|
||||
print(file_editor(command='undo_edit', path='{sandbox_dir}/test.txt'))
|
||||
"""
|
||||
action = IPythonRunCellAction(code=test_code)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert 'File created successfully' in obs.content
|
||||
assert 'Line 1' in obs.content
|
||||
assert 'Line 2' in obs.content
|
||||
assert 'Line 3' in obs.content
|
||||
assert 'New Line 2' in obs.content
|
||||
assert 'Last edit to' in obs.content
|
||||
assert 'undone successfully' in obs.content
|
||||
|
||||
# Clean up
|
||||
action = CmdRunAction(command=f'rm -f {sandbox_dir}/test.txt')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
|
||||
action = CmdRunAction(command='sudo rm -f /root/test.txt')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
|
||||
_close_test_runtime(runtime)
|
||||
|
||||
@@ -207,7 +207,9 @@ async def test_run_controller_stop_with_stuck(mock_agent, mock_event_stream):
|
||||
'Non fatal error here to trigger loop'
|
||||
)
|
||||
non_fatal_error_obs._cause = event.id
|
||||
await event_stream.async_add_event(non_fatal_error_obs, EventSource.USER)
|
||||
await event_stream.async_add_event(
|
||||
non_fatal_error_obs, EventSource.ENVIRONMENT
|
||||
)
|
||||
|
||||
event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event)
|
||||
runtime.event_stream = event_stream
|
||||
|
||||
+37
-37
@@ -80,7 +80,7 @@ class TestStuckDetector:
|
||||
code=code_snippet,
|
||||
)
|
||||
ipython_observation._cause = ipython_action._id
|
||||
event_stream.add_event(ipython_observation, EventSource.USER)
|
||||
event_stream.add_event(ipython_observation, EventSource.ENVIRONMENT)
|
||||
|
||||
def _impl_unterminated_string_error_events(
|
||||
self, event_stream: EventStream, random_line: bool, incidents: int = 4
|
||||
@@ -96,7 +96,7 @@ class TestStuckDetector:
|
||||
code=code_snippet,
|
||||
)
|
||||
ipython_observation._cause = ipython_action._id
|
||||
event_stream.add_event(ipython_observation, EventSource.USER)
|
||||
event_stream.add_event(ipython_observation, EventSource.ENVIRONMENT)
|
||||
|
||||
def test_history_too_short(
|
||||
self, stuck_detector: StuckDetector, event_stream: EventStream
|
||||
@@ -106,7 +106,7 @@ class TestStuckDetector:
|
||||
observation = NullObservation(content='')
|
||||
observation._cause = message_action.id
|
||||
event_stream.add_event(message_action, EventSource.USER)
|
||||
event_stream.add_event(observation, EventSource.USER)
|
||||
event_stream.add_event(observation, EventSource.ENVIRONMENT)
|
||||
|
||||
cmd_action = CmdRunAction(command='ls')
|
||||
event_stream.add_event(cmd_action, EventSource.AGENT)
|
||||
@@ -114,7 +114,7 @@ class TestStuckDetector:
|
||||
command_id=1, command='ls', content='file1.txt\nfile2.txt'
|
||||
)
|
||||
cmd_observation._cause = cmd_action._id
|
||||
event_stream.add_event(cmd_observation, EventSource.USER)
|
||||
event_stream.add_event(cmd_observation, EventSource.ENVIRONMENT)
|
||||
|
||||
# stuck_detector.state.history.set_event_stream(event_stream)
|
||||
|
||||
@@ -131,7 +131,7 @@ class TestStuckDetector:
|
||||
|
||||
# 2 events
|
||||
event_stream.add_event(hello_action, EventSource.USER)
|
||||
event_stream.add_event(hello_observation, EventSource.USER)
|
||||
event_stream.add_event(hello_observation, EventSource.ENVIRONMENT)
|
||||
|
||||
cmd_action_1 = CmdRunAction(command='ls')
|
||||
event_stream.add_event(cmd_action_1, EventSource.AGENT)
|
||||
@@ -139,7 +139,7 @@ class TestStuckDetector:
|
||||
content='', command='ls', command_id=cmd_action_1._id
|
||||
)
|
||||
cmd_observation_1._cause = cmd_action_1._id
|
||||
event_stream.add_event(cmd_observation_1, EventSource.USER)
|
||||
event_stream.add_event(cmd_observation_1, EventSource.ENVIRONMENT)
|
||||
# 4 events
|
||||
|
||||
cmd_action_2 = CmdRunAction(command='ls')
|
||||
@@ -148,13 +148,13 @@ class TestStuckDetector:
|
||||
content='', command='ls', command_id=cmd_action_2._id
|
||||
)
|
||||
cmd_observation_2._cause = cmd_action_2._id
|
||||
event_stream.add_event(cmd_observation_2, EventSource.USER)
|
||||
event_stream.add_event(cmd_observation_2, EventSource.ENVIRONMENT)
|
||||
# 6 events
|
||||
|
||||
# random user message just because we can
|
||||
message_null_observation = NullObservation(content='')
|
||||
event_stream.add_event(message_action, EventSource.USER)
|
||||
event_stream.add_event(message_null_observation, EventSource.USER)
|
||||
event_stream.add_event(message_null_observation, EventSource.ENVIRONMENT)
|
||||
# 8 events
|
||||
|
||||
assert stuck_detector.is_stuck() is False
|
||||
@@ -166,7 +166,7 @@ class TestStuckDetector:
|
||||
content='', command='ls', command_id=cmd_action_3._id
|
||||
)
|
||||
cmd_observation_3._cause = cmd_action_3._id
|
||||
event_stream.add_event(cmd_observation_3, EventSource.USER)
|
||||
event_stream.add_event(cmd_observation_3, EventSource.ENVIRONMENT)
|
||||
# 10 events
|
||||
|
||||
assert len(collect_events(event_stream)) == 10
|
||||
@@ -191,7 +191,7 @@ class TestStuckDetector:
|
||||
content='', command='ls', command_id=cmd_action_4._id
|
||||
)
|
||||
cmd_observation_4._cause = cmd_action_4._id
|
||||
event_stream.add_event(cmd_observation_4, EventSource.USER)
|
||||
event_stream.add_event(cmd_observation_4, EventSource.ENVIRONMENT)
|
||||
# 12 events
|
||||
|
||||
assert len(collect_events(event_stream)) == 12
|
||||
@@ -223,14 +223,14 @@ class TestStuckDetector:
|
||||
hello_observation = NullObservation(content='')
|
||||
event_stream.add_event(hello_action, EventSource.USER)
|
||||
hello_observation._cause = hello_action._id
|
||||
event_stream.add_event(hello_observation, EventSource.USER)
|
||||
event_stream.add_event(hello_observation, EventSource.ENVIRONMENT)
|
||||
# 2 events
|
||||
|
||||
cmd_action_1 = CmdRunAction(command='invalid_command')
|
||||
event_stream.add_event(cmd_action_1, EventSource.AGENT)
|
||||
error_observation_1 = ErrorObservation(content='Command not found')
|
||||
error_observation_1._cause = cmd_action_1._id
|
||||
event_stream.add_event(error_observation_1, EventSource.USER)
|
||||
event_stream.add_event(error_observation_1, EventSource.ENVIRONMENT)
|
||||
# 4 events
|
||||
|
||||
cmd_action_2 = CmdRunAction(command='invalid_command')
|
||||
@@ -239,26 +239,26 @@ class TestStuckDetector:
|
||||
content='Command still not found or another error'
|
||||
)
|
||||
error_observation_2._cause = cmd_action_2._id
|
||||
event_stream.add_event(error_observation_2, EventSource.USER)
|
||||
event_stream.add_event(error_observation_2, EventSource.ENVIRONMENT)
|
||||
# 6 events
|
||||
|
||||
message_null_observation = NullObservation(content='')
|
||||
event_stream.add_event(message_action, EventSource.USER)
|
||||
event_stream.add_event(message_null_observation, EventSource.USER)
|
||||
event_stream.add_event(message_null_observation, EventSource.ENVIRONMENT)
|
||||
# 8 events
|
||||
|
||||
cmd_action_3 = CmdRunAction(command='invalid_command')
|
||||
event_stream.add_event(cmd_action_3, EventSource.AGENT)
|
||||
error_observation_3 = ErrorObservation(content='Different error')
|
||||
error_observation_3._cause = cmd_action_3._id
|
||||
event_stream.add_event(error_observation_3, EventSource.USER)
|
||||
event_stream.add_event(error_observation_3, EventSource.ENVIRONMENT)
|
||||
# 10 events
|
||||
|
||||
cmd_action_4 = CmdRunAction(command='invalid_command')
|
||||
event_stream.add_event(cmd_action_4, EventSource.AGENT)
|
||||
error_observation_4 = ErrorObservation(content='Command not found')
|
||||
error_observation_4._cause = cmd_action_4._id
|
||||
event_stream.add_event(error_observation_4, EventSource.USER)
|
||||
event_stream.add_event(error_observation_4, EventSource.ENVIRONMENT)
|
||||
# 12 events
|
||||
|
||||
with patch('logging.Logger.warning') as mock_warning:
|
||||
@@ -366,7 +366,7 @@ class TestStuckDetector:
|
||||
code='print("hello',
|
||||
)
|
||||
ipython_observation_1._cause = ipython_action_1._id
|
||||
event_stream.add_event(ipython_observation_1, EventSource.USER)
|
||||
event_stream.add_event(ipython_observation_1, EventSource.ENVIRONMENT)
|
||||
|
||||
ipython_action_2 = IPythonRunCellAction(code='print("hello')
|
||||
event_stream.add_event(ipython_action_2, EventSource.AGENT)
|
||||
@@ -375,7 +375,7 @@ class TestStuckDetector:
|
||||
code='print("hello',
|
||||
)
|
||||
ipython_observation_2._cause = ipython_action_2._id
|
||||
event_stream.add_event(ipython_observation_2, EventSource.USER)
|
||||
event_stream.add_event(ipython_observation_2, EventSource.ENVIRONMENT)
|
||||
|
||||
ipython_action_3 = IPythonRunCellAction(code='print("hello')
|
||||
event_stream.add_event(ipython_action_3, EventSource.AGENT)
|
||||
@@ -384,7 +384,7 @@ class TestStuckDetector:
|
||||
code='print("hello',
|
||||
)
|
||||
ipython_observation_3._cause = ipython_action_3._id
|
||||
event_stream.add_event(ipython_observation_3, EventSource.USER)
|
||||
event_stream.add_event(ipython_observation_3, EventSource.ENVIRONMENT)
|
||||
|
||||
ipython_action_4 = IPythonRunCellAction(code='print("hello')
|
||||
event_stream.add_event(ipython_action_4, EventSource.AGENT)
|
||||
@@ -393,7 +393,7 @@ class TestStuckDetector:
|
||||
code='print("hello',
|
||||
)
|
||||
ipython_observation_4._cause = ipython_action_4._id
|
||||
event_stream.add_event(ipython_observation_4, EventSource.USER)
|
||||
event_stream.add_event(ipython_observation_4, EventSource.ENVIRONMENT)
|
||||
|
||||
with patch('logging.Logger.warning') as mock_warning:
|
||||
assert stuck_detector.is_stuck() is False
|
||||
@@ -406,7 +406,7 @@ class TestStuckDetector:
|
||||
message_action._source = EventSource.USER
|
||||
event_stream.add_event(message_action, EventSource.USER)
|
||||
message_observation = NullObservation(content='')
|
||||
event_stream.add_event(message_observation, EventSource.USER)
|
||||
event_stream.add_event(message_observation, EventSource.ENVIRONMENT)
|
||||
|
||||
cmd_action_1 = CmdRunAction(command='ls')
|
||||
event_stream.add_event(cmd_action_1, EventSource.AGENT)
|
||||
@@ -414,7 +414,7 @@ class TestStuckDetector:
|
||||
command_id=1, command='ls', content='file1.txt\nfile2.txt'
|
||||
)
|
||||
cmd_observation_1._cause = cmd_action_1._id
|
||||
event_stream.add_event(cmd_observation_1, EventSource.USER)
|
||||
event_stream.add_event(cmd_observation_1, EventSource.ENVIRONMENT)
|
||||
|
||||
read_action_1 = FileReadAction(path='file1.txt')
|
||||
event_stream.add_event(read_action_1, EventSource.AGENT)
|
||||
@@ -422,7 +422,7 @@ class TestStuckDetector:
|
||||
content='File content', path='file1.txt'
|
||||
)
|
||||
read_observation_1._cause = read_action_1._id
|
||||
event_stream.add_event(read_observation_1, EventSource.USER)
|
||||
event_stream.add_event(read_observation_1, EventSource.ENVIRONMENT)
|
||||
|
||||
cmd_action_2 = CmdRunAction(command='ls')
|
||||
event_stream.add_event(cmd_action_2, EventSource.AGENT)
|
||||
@@ -430,7 +430,7 @@ class TestStuckDetector:
|
||||
command_id=2, command='ls', content='file1.txt\nfile2.txt'
|
||||
)
|
||||
cmd_observation_2._cause = cmd_action_2._id
|
||||
event_stream.add_event(cmd_observation_2, EventSource.USER)
|
||||
event_stream.add_event(cmd_observation_2, EventSource.ENVIRONMENT)
|
||||
|
||||
read_action_2 = FileReadAction(path='file1.txt')
|
||||
event_stream.add_event(read_action_2, EventSource.AGENT)
|
||||
@@ -438,12 +438,12 @@ class TestStuckDetector:
|
||||
content='File content', path='file1.txt'
|
||||
)
|
||||
read_observation_2._cause = read_action_2._id
|
||||
event_stream.add_event(read_observation_2, EventSource.USER)
|
||||
event_stream.add_event(read_observation_2, EventSource.ENVIRONMENT)
|
||||
|
||||
# one more message to break the pattern
|
||||
message_null_observation = NullObservation(content='')
|
||||
event_stream.add_event(message_action, EventSource.USER)
|
||||
event_stream.add_event(message_null_observation, EventSource.USER)
|
||||
event_stream.add_event(message_null_observation, EventSource.ENVIRONMENT)
|
||||
|
||||
cmd_action_3 = CmdRunAction(command='ls')
|
||||
event_stream.add_event(cmd_action_3, EventSource.AGENT)
|
||||
@@ -451,7 +451,7 @@ class TestStuckDetector:
|
||||
command_id=3, command='ls', content='file1.txt\nfile2.txt'
|
||||
)
|
||||
cmd_observation_3._cause = cmd_action_3._id
|
||||
event_stream.add_event(cmd_observation_3, EventSource.USER)
|
||||
event_stream.add_event(cmd_observation_3, EventSource.ENVIRONMENT)
|
||||
|
||||
read_action_3 = FileReadAction(path='file1.txt')
|
||||
event_stream.add_event(read_action_3, EventSource.AGENT)
|
||||
@@ -459,7 +459,7 @@ class TestStuckDetector:
|
||||
content='File content', path='file1.txt'
|
||||
)
|
||||
read_observation_3._cause = read_action_3._id
|
||||
event_stream.add_event(read_observation_3, EventSource.USER)
|
||||
event_stream.add_event(read_observation_3, EventSource.ENVIRONMENT)
|
||||
|
||||
with patch('logging.Logger.warning') as mock_warning:
|
||||
assert stuck_detector.is_stuck() is True
|
||||
@@ -475,7 +475,7 @@ class TestStuckDetector:
|
||||
event_stream.add_event(hello_action, EventSource.USER)
|
||||
hello_observation = NullObservation(content='')
|
||||
hello_observation._cause = hello_action._id
|
||||
event_stream.add_event(hello_observation, EventSource.USER)
|
||||
event_stream.add_event(hello_observation, EventSource.ENVIRONMENT)
|
||||
|
||||
cmd_action_1 = CmdRunAction(command='ls')
|
||||
event_stream.add_event(cmd_action_1, EventSource.AGENT)
|
||||
@@ -483,7 +483,7 @@ class TestStuckDetector:
|
||||
command_id=cmd_action_1.id, command='ls', content='file1.txt\nfile2.txt'
|
||||
)
|
||||
cmd_observation_1._cause = cmd_action_1._id
|
||||
event_stream.add_event(cmd_observation_1, EventSource.USER)
|
||||
event_stream.add_event(cmd_observation_1, EventSource.ENVIRONMENT)
|
||||
|
||||
read_action_1 = FileReadAction(path='file1.txt')
|
||||
event_stream.add_event(read_action_1, EventSource.AGENT)
|
||||
@@ -491,7 +491,7 @@ class TestStuckDetector:
|
||||
content='File content', path='file1.txt'
|
||||
)
|
||||
read_observation_1._cause = read_action_1._id
|
||||
event_stream.add_event(read_observation_1, EventSource.USER)
|
||||
event_stream.add_event(read_observation_1, EventSource.ENVIRONMENT)
|
||||
|
||||
cmd_action_2 = CmdRunAction(command='pwd')
|
||||
event_stream.add_event(cmd_action_2, EventSource.AGENT)
|
||||
@@ -499,7 +499,7 @@ class TestStuckDetector:
|
||||
command_id=2, command='pwd', content='/home/user'
|
||||
)
|
||||
cmd_observation_2._cause = cmd_action_2._id
|
||||
event_stream.add_event(cmd_observation_2, EventSource.USER)
|
||||
event_stream.add_event(cmd_observation_2, EventSource.ENVIRONMENT)
|
||||
|
||||
read_action_2 = FileReadAction(path='file2.txt')
|
||||
event_stream.add_event(read_action_2, EventSource.AGENT)
|
||||
@@ -507,11 +507,11 @@ class TestStuckDetector:
|
||||
content='Another file content', path='file2.txt'
|
||||
)
|
||||
read_observation_2._cause = read_action_2._id
|
||||
event_stream.add_event(read_observation_2, EventSource.USER)
|
||||
event_stream.add_event(read_observation_2, EventSource.ENVIRONMENT)
|
||||
|
||||
message_null_observation = NullObservation(content='')
|
||||
event_stream.add_event(message_action, EventSource.USER)
|
||||
event_stream.add_event(message_null_observation, EventSource.USER)
|
||||
event_stream.add_event(message_null_observation, EventSource.ENVIRONMENT)
|
||||
|
||||
cmd_action_3 = CmdRunAction(command='pwd')
|
||||
event_stream.add_event(cmd_action_3, EventSource.AGENT)
|
||||
@@ -519,7 +519,7 @@ class TestStuckDetector:
|
||||
command_id=cmd_action_3.id, command='pwd', content='/home/user'
|
||||
)
|
||||
cmd_observation_3._cause = cmd_action_3._id
|
||||
event_stream.add_event(cmd_observation_3, EventSource.USER)
|
||||
event_stream.add_event(cmd_observation_3, EventSource.ENVIRONMENT)
|
||||
|
||||
read_action_3 = FileReadAction(path='file2.txt')
|
||||
event_stream.add_event(read_action_3, EventSource.AGENT)
|
||||
@@ -527,7 +527,7 @@ class TestStuckDetector:
|
||||
content='Another file content', path='file2.txt'
|
||||
)
|
||||
read_observation_3._cause = read_action_3._id
|
||||
event_stream.add_event(read_observation_3, EventSource.USER)
|
||||
event_stream.add_event(read_observation_3, EventSource.ENVIRONMENT)
|
||||
|
||||
assert stuck_detector.is_stuck() is False
|
||||
|
||||
@@ -572,7 +572,7 @@ class TestStuckDetector:
|
||||
exit_code=0,
|
||||
)
|
||||
cmd_output_observation._cause = cmd_kill_action._id
|
||||
event_stream.add_event(cmd_output_observation, EventSource.USER)
|
||||
event_stream.add_event(cmd_output_observation, EventSource.ENVIRONMENT)
|
||||
|
||||
message_action_7 = MessageAction(content="I'm doing well, thanks for asking.")
|
||||
event_stream.add_event(message_action_7, EventSource.AGENT)
|
||||
|
||||
@@ -50,6 +50,7 @@ def test_llm_init_with_model_info(mock_get_model_info, default_config):
|
||||
'max_output_tokens': 2000,
|
||||
}
|
||||
llm = LLM(default_config)
|
||||
llm.init_model_info()
|
||||
assert llm.config.max_input_tokens == 8000
|
||||
assert llm.config.max_output_tokens == 2000
|
||||
|
||||
@@ -58,6 +59,7 @@ def test_llm_init_with_model_info(mock_get_model_info, default_config):
|
||||
def test_llm_init_without_model_info(mock_get_model_info, default_config):
|
||||
mock_get_model_info.side_effect = Exception('Model info not available')
|
||||
llm = LLM(default_config)
|
||||
llm.init_model_info()
|
||||
assert llm.config.max_input_tokens == 4096
|
||||
assert llm.config.max_output_tokens == 4096
|
||||
|
||||
@@ -108,6 +110,7 @@ def test_llm_init_with_openrouter_model(mock_get_model_info, default_config):
|
||||
'max_output_tokens': 1500,
|
||||
}
|
||||
llm = LLM(default_config)
|
||||
llm.init_model_info()
|
||||
assert llm.config.max_input_tokens == 7000
|
||||
assert llm.config.max_output_tokens == 1500
|
||||
mock_get_model_info.assert_called_once_with('openrouter:gpt-4o-mini')
|
||||
|
||||
@@ -88,7 +88,7 @@ def _create_observation_event(observation: str) -> Event:
|
||||
event = Event()
|
||||
event._id = -1
|
||||
event._timestamp = datetime.now(timezone.utc).isoformat()
|
||||
event._source = EventSource.USER
|
||||
event._source = EventSource.ENVIRONMENT
|
||||
event.observation = observation
|
||||
return event
|
||||
|
||||
|
||||
@@ -155,7 +155,7 @@ def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream):
|
||||
command='ls -l',
|
||||
exit_code=0,
|
||||
)
|
||||
mock_event_stream.add_event(cmd_observation_1, EventSource.USER)
|
||||
mock_event_stream.add_event(cmd_observation_1, EventSource.ENVIRONMENT)
|
||||
|
||||
message_action_2 = MessageAction("Now, let's create a new directory.")
|
||||
mock_event_stream.add_event(message_action_2, EventSource.AGENT)
|
||||
@@ -169,7 +169,7 @@ def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream):
|
||||
command='mkdir new_directory',
|
||||
exit_code=0,
|
||||
)
|
||||
mock_event_stream.add_event(cmd_observation_2, EventSource.USER)
|
||||
mock_event_stream.add_event(cmd_observation_2, EventSource.ENVIRONMENT)
|
||||
|
||||
codeact_agent.reset()
|
||||
messages = codeact_agent._get_messages(
|
||||
|
||||
Reference in New Issue
Block a user