mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-09 14:57:59 -05:00
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
This commit is contained in:
129
evaluation/benchmarks/mint/env.py
Normal file
129
evaluation/benchmarks/mint/env.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import re
|
||||
import traceback
|
||||
|
||||
from datatypes import ParseError, StepOutput, TaskState
|
||||
from tasks.base import Task
|
||||
|
||||
from openhands.controller.state.state import State
|
||||
|
||||
|
||||
class SimplifiedEnv:
|
||||
INVALID_INPUT_MESSAGE = (
|
||||
"I don't understand your input. \n"
|
||||
'If you want to execute code, please use <execute_ipython> YOUR_CODE_HERE </execute_ipython>.\n'
|
||||
'If you want to give me an answer, please use <solution> YOUR_SOLUTION_HERE </solution>.\n'
|
||||
'For example: The answer to the question is <solution> 42 </solution>. \n'
|
||||
)
|
||||
|
||||
def __init__(self, agent_state: State, task: Task, task_config: dict[str, int]):
|
||||
self.agent_state = agent_state
|
||||
self.task = task
|
||||
|
||||
agent_action_count = {
|
||||
'propose_solution': 0,
|
||||
'use_tool': 0,
|
||||
'invalid_action': 0,
|
||||
}
|
||||
# check if agent_state has attribute turn_info set
|
||||
if hasattr(self.agent_state, 'propose_solution_count'):
|
||||
agent_action_count['propose_solution'] = (
|
||||
self.agent_state.propose_solution_count
|
||||
)
|
||||
|
||||
self.task_state = TaskState(agent_action_count=agent_action_count)
|
||||
|
||||
self.task_config = task_config
|
||||
|
||||
def step(self, lm_message: str):
|
||||
observation = self.handle_propose_solution(lm_message)
|
||||
|
||||
self.check_max_iteration()
|
||||
|
||||
turn_info = (
|
||||
self.task_config['max_iterations'] - self.agent_state.iteration,
|
||||
self.task_config['max_propose_solution']
|
||||
- self.task_state.agent_action_count['propose_solution'],
|
||||
)
|
||||
|
||||
output = StepOutput(
|
||||
observation=observation,
|
||||
success=self.task_state.success,
|
||||
turn_info=turn_info,
|
||||
)
|
||||
|
||||
self.agent_state.propose_solution_count = self.task_state.agent_action_count[
|
||||
'propose_solution'
|
||||
]
|
||||
self.log_output(output)
|
||||
return self.task_state
|
||||
|
||||
def handle_propose_solution(self, lm_message) -> str | None:
|
||||
"""Propose answer to check the task success.
|
||||
|
||||
It might set self.state.finished = True if the task is successful.
|
||||
"""
|
||||
self.task_state.agent_action_count['propose_solution'] += 1
|
||||
try:
|
||||
parsed = self.parse_propose_solution(lm_message)
|
||||
task_success = self.check_task_success(parsed['answer'])
|
||||
if task_success:
|
||||
self.task_state.finished = True
|
||||
self.task_state.success = True
|
||||
self.task_state.terminate_reason = 'task_success'
|
||||
# NOTE: should not return the function now, because we need to log the output
|
||||
# Set state.finished = True will terminate the episode
|
||||
except ParseError:
|
||||
return SimplifiedEnv.INVALID_INPUT_MESSAGE
|
||||
except Exception:
|
||||
error_traceback = traceback.format_exc()
|
||||
return f'{error_traceback}'
|
||||
|
||||
def parse_propose_solution(self, lm_message: str) -> dict:
|
||||
"""Define the parsing logic."""
|
||||
lm_output = '\n' + lm_message + '\n'
|
||||
|
||||
answer = '\n'.join(
|
||||
[
|
||||
i.strip()
|
||||
for i in re.findall(r'<solution>(.*?)</solution>', lm_output, re.DOTALL)
|
||||
]
|
||||
)
|
||||
if answer == '':
|
||||
raise ParseError('No answer found.')
|
||||
|
||||
return {'answer': answer}
|
||||
|
||||
def log_output(self, output: StepOutput) -> None:
|
||||
if self.task_state.finished:
|
||||
return
|
||||
|
||||
content = output.to_str()
|
||||
self.task_state.latest_output = output.to_dict()
|
||||
self.task_state.latest_output['content'] = content
|
||||
|
||||
def check_task_success(self, answer: str) -> bool:
|
||||
# log_message.info(f"STUDENT ANSWER: [{answer}]")
|
||||
# log_message.info(f"REFERENCE ANSWER: [{self.task.reference}]")
|
||||
return self.task.success(answer)
|
||||
|
||||
def check_max_iteration(self):
|
||||
"""Check if the agent has reached the max iteration limit.
|
||||
|
||||
It might set self.state.finished = True if the agent has reached the max iteration limit.
|
||||
"""
|
||||
if self.task_state.finished:
|
||||
# ignore if the episode is already finished (e.g., task success)
|
||||
return
|
||||
|
||||
if (
|
||||
# propose solution > max output solution
|
||||
self.task_state.agent_action_count['propose_solution']
|
||||
>= self.task_config['max_propose_solution']
|
||||
):
|
||||
self.task_state.finished = True
|
||||
self.task_state.success = False
|
||||
self.task_state.terminate_reason = 'max_propose_steps'
|
||||
elif self.agent_state.iteration >= self.task_config['max_iterations']:
|
||||
self.task_state.finished = True
|
||||
self.task_state.success = False
|
||||
self.task_state.terminate_reason = 'max_iterations'
|
||||
Reference in New Issue
Block a user