Mirror of https://github.com/All-Hands-AI/OpenHands.git (synced 2026-01-10 23:38:08 -05:00)
Re-adjust ssh_box for parallel evaluation (#1729)
* update ssh_box
* fix controller in test

Co-authored-by: Boxuan Li <liboxuan@connect.hku.hk>
@@ -3,6 +3,7 @@ import json
 import os
 import sys
 import tarfile
+import tempfile
 import time
 import uuid
 from collections import namedtuple
@@ -172,8 +173,19 @@ class DockerSSHBox(Sandbox):
         self._ssh_port = find_available_tcp_port()

         # always restart the container, cuz the initial be regarded as a new session
-        self.restart_docker_container()
-
+        n_tries = 5
+        while n_tries > 0:
+            try:
+                self.restart_docker_container()
+                break
+            except Exception as e:
+                logger.exception(
+                    'Failed to start Docker container, retrying...', exc_info=False
+                )
+                n_tries -= 1
+                if n_tries == 0:
+                    raise e
+                time.sleep(5)
         self.setup_user()
         self.start_ssh_session()
         # make sure /tmp always exists
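The retry loop added above makes container startup resilient to transient Docker errors (for example port races, which become likely when many sandboxes start in parallel) instead of failing on the first attempt. A minimal standalone sketch of the same pattern, with a hypothetical start_container callable standing in for restart_docker_container:

import logging
import time

logger = logging.getLogger(__name__)


def start_with_retries(start_container, n_tries: int = 5, delay: float = 5.0) -> None:
    """Call start_container(), retrying on failure and re-raising once retries run out."""
    while n_tries > 0:
        try:
            start_container()
            return
        except Exception:
            logger.exception('Failed to start Docker container, retrying...')
            n_tries -= 1
            if n_tries == 0:
                raise
            time.sleep(delay)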
@@ -313,6 +325,24 @@ class DockerSSHBox(Sandbox):
         bg_cmd = self.background_commands[id]
         return bg_cmd.read_logs()

+    def _send_interrupt(
+        self,
+        cmd: str,
+        prev_output: str = '',
+        ignore_last_output: bool = False,
+    ) -> tuple[int, str]:
+        logger.exception('Command timed out, killing process...', exc_info=False)
+        # send a SIGINT to the process
+        self.ssh.sendintr()
+        self.ssh.prompt()
+        command_output = prev_output
+        if not ignore_last_output:
+            command_output += '\n' + self.ssh.before.decode('utf-8')
+        return (
+            -1,
+            f'Command: "{cmd}" timed out. Sending SIGINT to the process: {command_output}',
+        )
+
     def execute(self, cmd: str) -> Tuple[int, str]:
         commands = split_bash_commands(cmd)
         if len(commands) > 1:
@@ -329,14 +359,7 @@ class DockerSSHBox(Sandbox):
         success = self.ssh.prompt(timeout=self.timeout)
         if not success:
-            logger.exception('Command timed out, killing process...', exc_info=False)
-            # send a SIGINT to the process
-            self.ssh.sendintr()
-            self.ssh.prompt()
-            command_output = self.ssh.before.decode('utf-8')
-            return (
-                -1,
-                f'Command: "{cmd}" timed out. Sending SIGINT to the process: {command_output}',
-            )
+            return self._send_interrupt(cmd)
         command_output = self.ssh.before.decode('utf-8')

         # once out, make sure that we have *every* output, we while loop until we get an empty output
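The two hunks above factor the timeout handling into _send_interrupt, so both the initial prompt() wait and the later exit-code wait can reuse it. A rough sketch of the underlying pexpect pattern, assuming ssh is an already logged-in pexpect.pxssh.pxssh session; the run_with_timeout helper is illustrative, and a real caller would still read the exit code the way the file does:

from pexpect import pxssh


def run_with_timeout(ssh: pxssh.pxssh, cmd: str, timeout: int = 10) -> tuple[int, str]:
    """Run cmd over an existing pxssh session; on timeout, send Ctrl-C and report failure."""
    ssh.sendline(cmd)
    if not ssh.prompt(timeout=timeout):
        # prompt() returns False when the shell prompt does not come back in time
        ssh.sendintr()   # deliver SIGINT (Ctrl-C) to the foreground process
        ssh.prompt()     # wait for the prompt to reappear after the interrupt
        return -1, f'Command: "{cmd}" timed out: ' + ssh.before.decode('utf-8')
    return 0, ssh.before.decode('utf-8')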
@@ -361,10 +384,15 @@ class DockerSSHBox(Sandbox):
         self.ssh.sendline('echo $?')
         self.ssh.prompt()
         exit_code_str = self.ssh.before.decode('utf-8')
+        _start_time = time.time()
         while not exit_code_str:
             self.ssh.prompt()
             exit_code_str = self.ssh.before.decode('utf-8')
             logger.debug(f'WAITING FOR exit code: {exit_code_str}')
+            if time.time() - _start_time > self.timeout:
+                return self._send_interrupt(
+                    cmd, command_output, ignore_last_output=True
+                )
         exit_code = int(exit_code_str.strip())
         return exit_code, command_output
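This hunk adds a deadline to the echo $? polling loop: an empty read here could previously spin forever, while now the command is interrupted once self.timeout elapses. A condensed sketch of the polling idea, again assuming a pxssh session and assuming the shell does not echo the command back (otherwise the reply would need extra parsing):

import time

from pexpect import pxssh


def read_exit_code(ssh: pxssh.pxssh, timeout: float = 10.0) -> int:
    """Ask the remote shell for $? and poll until a non-empty reply arrives or the deadline passes."""
    ssh.sendline('echo $?')
    ssh.prompt()
    exit_code_str = ssh.before.decode('utf-8')
    deadline = time.time() + timeout
    while not exit_code_str.strip():
        if time.time() > deadline:
            raise TimeoutError('no exit code received from the shell')
        ssh.prompt()
        exit_code_str = ssh.before.decode('utf-8')
    return int(exit_code_str.strip())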
@@ -380,32 +408,34 @@ class DockerSSHBox(Sandbox):
                 f'Failed to create directory {sandbox_dest} in sandbox: {logs}'
             )

-        if recursive:
-            assert os.path.isdir(
-                host_src
-            ), 'Source must be a directory when recursive is True'
-            files = glob(host_src + '/**/*', recursive=True)
-            srcname = os.path.basename(host_src)
-            tar_filename = os.path.join(os.path.dirname(host_src), srcname + '.tar')
-            with tarfile.open(tar_filename, mode='w') as tar:
-                for file in files:
-                    tar.add(
-                        file, arcname=os.path.relpath(file, os.path.dirname(host_src))
-                    )
-        else:
-            assert os.path.isfile(
-                host_src
-            ), 'Source must be a file when recursive is False'
-            srcname = os.path.basename(host_src)
-            tar_filename = os.path.join(os.path.dirname(host_src), srcname + '.tar')
-            with tarfile.open(tar_filename, mode='w') as tar:
-                tar.add(host_src, arcname=srcname)
+        # use temp directory to store the tar file to avoid
+        # conflict of filename when running multi-processes
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            if recursive:
+                assert os.path.isdir(
+                    host_src
+                ), 'Source must be a directory when recursive is True'
+                files = glob(host_src + '/**/*', recursive=True)
+                srcname = os.path.basename(host_src)
+                tar_filename = os.path.join(tmp_dir, srcname + '.tar')
+                with tarfile.open(tar_filename, mode='w') as tar:
+                    for file in files:
+                        tar.add(
+                            file,
+                            arcname=os.path.relpath(file, os.path.dirname(host_src)),
+                        )
+            else:
+                assert os.path.isfile(
+                    host_src
+                ), 'Source must be a file when recursive is False'
+                srcname = os.path.basename(host_src)
+                tar_filename = os.path.join(tmp_dir, srcname + '.tar')
+                with tarfile.open(tar_filename, mode='w') as tar:
+                    tar.add(host_src, arcname=srcname)

-        with open(tar_filename, 'rb') as f:
-            data = f.read()
-
-        self.container.put_archive(os.path.dirname(sandbox_dest), data)
-        os.remove(tar_filename)
+            with open(tar_filename, 'rb') as f:
+                data = f.read()
+            self.container.put_archive(os.path.dirname(sandbox_dest), data)

     def execute_in_background(self, cmd: str) -> Process:
         result = self.container.exec_run(
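put_archive expects a tar byte stream, and writing the intermediate .tar into a per-call TemporaryDirectory instead of next to host_src is what keeps parallel evaluation workers from overwriting each other's archives. A condensed sketch of the single-file path using the docker SDK; the container handle and the paths in the usage comment are illustrative, and the destination directory is assumed to already exist in the container (the method above creates it with mkdir -p first):

import os
import tarfile
import tempfile

import docker


def copy_file_to_container(container, host_src: str, sandbox_dest: str) -> None:
    """Tar a single host file inside a private temp dir, then stream it into the container."""
    srcname = os.path.basename(host_src)
    with tempfile.TemporaryDirectory() as tmp_dir:
        tar_filename = os.path.join(tmp_dir, srcname + '.tar')
        with tarfile.open(tar_filename, mode='w') as tar:
            tar.add(host_src, arcname=srcname)
        with open(tar_filename, 'rb') as f:
            data = f.read()
    # put_archive unpacks the tar under the destination directory inside the container
    container.put_archive(sandbox_dest, data)


# usage (assumes a running container named 'sandbox'):
# client = docker.from_env()
# copy_file_to_container(client.containers.get('sandbox'), '/tmp/hello.txt', '/workspace')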
@@ -502,6 +532,21 @@ class DockerSSHBox(Sandbox):
         except docker.errors.NotFound:
             return False

+    @property
+    def volumes(self):
+        mount_dir = config.workspace_mount_path
+        logger.info(f'Mounting workspace directory: {mount_dir}')
+        return {
+            mount_dir: {'bind': self.sandbox_workspace_dir, 'mode': 'rw'},
+            # mount cache directory to /home/opendevin/.cache for pip cache reuse
+            config.cache_dir: {
+                'bind': (
+                    '/home/opendevin/.cache' if self.run_as_devin else '/root/.cache'
+                ),
+                'mode': 'rw',
+            },
+        }
+
     def restart_docker_container(self):
         try:
             self.stop_docker_container()
@@ -525,9 +570,8 @@ class DockerSSHBox(Sandbox):
                     )
                 )

-            mount_dir = config.workspace_mount_path
-            logger.info(f'Mounting workspace directory: {mount_dir}')
             # start the container
+            logger.info(f'Mounting volumes: {self.volumes}')
             self.container = self.docker_client.containers.run(
                 self.container_image,
                 # allow root login
@@ -536,18 +580,7 @@ class DockerSSHBox(Sandbox):
                 working_dir=self.sandbox_workspace_dir,
                 name=self.container_name,
                 detach=True,
-                volumes={
-                    mount_dir: {'bind': self.sandbox_workspace_dir, 'mode': 'rw'},
-                    # mount cache directory to /home/opendevin/.cache for pip cache reuse
-                    config.cache_dir: {
-                        'bind': (
-                            '/home/opendevin/.cache'
-                            if self.run_as_devin
-                            else '/root/.cache'
-                        ),
-                        'mode': 'rw',
-                    },
-                },
+                volumes=self.volumes,
             )
             logger.info('Container started')
         except Exception as ex:
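Centralizing the bind mounts in the volumes property removes the duplicated dict literal from restart_docker_container and lets the same mapping be logged before the container starts. The docker SDK accepts the mapping in exactly this host-path to {'bind', 'mode'} shape; a minimal sketch with placeholder paths and image:

import docker

client = docker.from_env()

# host path -> bind target and mode, same shape as the `volumes` property above
volumes = {
    '/path/on/host/workspace': {'bind': '/workspace', 'mode': 'rw'},
    '/path/on/host/cache': {'bind': '/root/.cache', 'mode': 'rw'},
}

container = client.containers.run(
    'ubuntu:22.04',            # placeholder image
    command='sleep infinity',
    name='opendevin-sandbox-example',
    detach=True,
    volumes=volumes,
)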
@@ -578,7 +611,10 @@ class DockerSSHBox(Sandbox):
         containers = self.docker_client.containers.list(all=True)
         for container in containers:
             try:
-                if container.name.startswith(self.container_name_prefix):
+                if container.name.startswith(self.container_name):
+                    # only remove the container we created
+                    # otherwise all other containers with the same prefix will be removed
+                    # which will mess up with parallel evaluation
                     container.remove(force=True)
             except docker.errors.NotFound:
                 pass
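Matching on self.container_name rather than the shared container_name_prefix means each sandbox removes only the container it created, so one worker's cleanup can no longer delete containers belonging to other parallel evaluation runs. A small sketch of that narrower cleanup, assuming container names are built as the shared prefix plus a per-sandbox id:

import docker


def remove_own_containers(client: docker.DockerClient, container_name: str) -> None:
    """Remove only containers whose name starts with this sandbox's unique name."""
    for container in client.containers.list(all=True):
        try:
            if container.name.startswith(container_name):
                container.remove(force=True)
        except docker.errors.NotFound:
            # container disappeared between list() and remove(); nothing to do
            pass


# usage:
# remove_own_containers(docker.from_env(), 'opendevin-sandbox-1234abcd')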