Re-adjust ssh_box for parallel evaluation (#1729)

* update ssh_box

* fix controller in test

---------

Co-authored-by: Boxuan Li <liboxuan@connect.hku.hk>
This commit is contained in:
Xingyao Wang
2024-05-13 14:35:30 +08:00
committed by GitHub
parent ba8d8634ac
commit 00c0edae5f

View File

@@ -3,6 +3,7 @@ import json
import os
import sys
import tarfile
import tempfile
import time
import uuid
from collections import namedtuple
@@ -172,8 +173,19 @@ class DockerSSHBox(Sandbox):
self._ssh_port = find_available_tcp_port()
# always restart the container, because initialization is regarded as a new session
self.restart_docker_container()
n_tries = 5
while n_tries > 0:
try:
self.restart_docker_container()
break
except Exception as e:
logger.exception(
'Failed to start Docker container, retrying...', exc_info=False
)
n_tries -= 1
if n_tries == 0:
raise e
time.sleep(5)
self.setup_user()
self.start_ssh_session()
# make sure /tmp always exists
@@ -313,6 +325,24 @@ class DockerSSHBox(Sandbox):
bg_cmd = self.background_commands[id]
return bg_cmd.read_logs()
def _send_interrupt(
    self,
    cmd: str,
    prev_output: str = '',
    ignore_last_output: bool = False,
) -> tuple[int, str]:
    """Interrupt a timed-out command and build the timeout result.

    Calls ``self.ssh.sendintr()`` to interrupt whatever is running in the
    SSH session, waits for the shell prompt to come back, and returns a
    failure tuple describing the timeout.

    Args:
        cmd: The command text; used only in the returned message.
        prev_output: Output already collected for the command. It is placed
            at the start of the reported output.
        ignore_last_output: When True, the text captured in
            ``self.ssh.before`` after the interrupt is NOT appended to
            ``prev_output``.

    Returns:
        ``(-1, message)`` where ``message`` names the command and includes
        the collected output.
    """
    # We are not inside an exception handler here, so there is no traceback
    # to attach: log at error level directly instead of logger.exception().
    logger.error('Command timed out, killing process...')
    # send a SIGINT to the process
    self.ssh.sendintr()
    self.ssh.prompt()
    command_output = prev_output
    if not ignore_last_output:
        # errors='replace' keeps the timeout report from being masked by a
        # UnicodeDecodeError when the interrupted command emitted bytes that
        # are not valid UTF-8.
        last_output = self.ssh.before.decode('utf-8', errors='replace')
        command_output += '\n' + last_output
    return (
        -1,
        f'Command: "{cmd}" timed out. Sending SIGINT to the process: {command_output}',
    )
def execute(self, cmd: str) -> Tuple[int, str]:
commands = split_bash_commands(cmd)
if len(commands) > 1:
@@ -329,14 +359,7 @@ class DockerSSHBox(Sandbox):
success = self.ssh.prompt(timeout=self.timeout)
if not success:
logger.exception('Command timed out, killing process...', exc_info=False)
# send a SIGINT to the process
self.ssh.sendintr()
self.ssh.prompt()
command_output = self.ssh.before.decode('utf-8')
return (
-1,
f'Command: "{cmd}" timed out. Sending SIGINT to the process: {command_output}',
)
return self._send_interrupt(cmd)
command_output = self.ssh.before.decode('utf-8')
# once out, make sure that we have *every* output: keep looping until we get an empty output
@@ -361,10 +384,15 @@ class DockerSSHBox(Sandbox):
self.ssh.sendline('echo $?')
self.ssh.prompt()
exit_code_str = self.ssh.before.decode('utf-8')
_start_time = time.time()
while not exit_code_str:
self.ssh.prompt()
exit_code_str = self.ssh.before.decode('utf-8')
logger.debug(f'WAITING FOR exit code: {exit_code_str}')
if time.time() - _start_time > self.timeout:
return self._send_interrupt(
cmd, command_output, ignore_last_output=True
)
exit_code = int(exit_code_str.strip())
return exit_code, command_output
@@ -380,32 +408,34 @@ class DockerSSHBox(Sandbox):
f'Failed to create directory {sandbox_dest} in sandbox: {logs}'
)
if recursive:
assert os.path.isdir(
host_src
), 'Source must be a directory when recursive is True'
files = glob(host_src + '/**/*', recursive=True)
srcname = os.path.basename(host_src)
tar_filename = os.path.join(os.path.dirname(host_src), srcname + '.tar')
with tarfile.open(tar_filename, mode='w') as tar:
for file in files:
tar.add(
file, arcname=os.path.relpath(file, os.path.dirname(host_src))
)
else:
assert os.path.isfile(
host_src
), 'Source must be a file when recursive is False'
srcname = os.path.basename(host_src)
tar_filename = os.path.join(os.path.dirname(host_src), srcname + '.tar')
with tarfile.open(tar_filename, mode='w') as tar:
tar.add(host_src, arcname=srcname)
# use temp directory to store the tar file to avoid
# conflict of filename when running multi-processes
with tempfile.TemporaryDirectory() as tmp_dir:
if recursive:
assert os.path.isdir(
host_src
), 'Source must be a directory when recursive is True'
files = glob(host_src + '/**/*', recursive=True)
srcname = os.path.basename(host_src)
tar_filename = os.path.join(tmp_dir, srcname + '.tar')
with tarfile.open(tar_filename, mode='w') as tar:
for file in files:
tar.add(
file,
arcname=os.path.relpath(file, os.path.dirname(host_src)),
)
else:
assert os.path.isfile(
host_src
), 'Source must be a file when recursive is False'
srcname = os.path.basename(host_src)
tar_filename = os.path.join(tmp_dir, srcname + '.tar')
with tarfile.open(tar_filename, mode='w') as tar:
tar.add(host_src, arcname=srcname)
with open(tar_filename, 'rb') as f:
data = f.read()
self.container.put_archive(os.path.dirname(sandbox_dest), data)
os.remove(tar_filename)
with open(tar_filename, 'rb') as f:
data = f.read()
self.container.put_archive(os.path.dirname(sandbox_dest), data)
def execute_in_background(self, cmd: str) -> Process:
result = self.container.exec_run(
@@ -502,6 +532,21 @@ class DockerSSHBox(Sandbox):
except docker.errors.NotFound:
return False
@property
def volumes(self):
    """Build the volume mapping passed to the Docker container.

    Returns a dict keyed by host path. Each value is a bind spec with the
    in-container target path and the read-write mode. Two entries are
    produced: the workspace mount and the cache directory mount. The
    workspace mount directory is logged on every access.
    """
    workspace_host_dir = config.workspace_mount_path
    logger.info(f'Mounting workspace directory: {workspace_host_dir}')

    mounts = {}
    mounts[workspace_host_dir] = {
        'bind': self.sandbox_workspace_dir,
        'mode': 'rw',
    }

    # mount cache directory to /home/opendevin/.cache for pip cache reuse
    cache_host_dir = config.cache_dir
    if self.run_as_devin:
        cache_target = '/home/opendevin/.cache'
    else:
        cache_target = '/root/.cache'
    mounts[cache_host_dir] = {'bind': cache_target, 'mode': 'rw'}

    return mounts
def restart_docker_container(self):
try:
self.stop_docker_container()
@@ -525,9 +570,8 @@ class DockerSSHBox(Sandbox):
)
)
mount_dir = config.workspace_mount_path
logger.info(f'Mounting workspace directory: {mount_dir}')
# start the container
logger.info(f'Mounting volumes: {self.volumes}')
self.container = self.docker_client.containers.run(
self.container_image,
# allow root login
@@ -536,18 +580,7 @@ class DockerSSHBox(Sandbox):
working_dir=self.sandbox_workspace_dir,
name=self.container_name,
detach=True,
volumes={
mount_dir: {'bind': self.sandbox_workspace_dir, 'mode': 'rw'},
# mount cache directory to /home/opendevin/.cache for pip cache reuse
config.cache_dir: {
'bind': (
'/home/opendevin/.cache'
if self.run_as_devin
else '/root/.cache'
),
'mode': 'rw',
},
},
volumes=self.volumes,
)
logger.info('Container started')
except Exception as ex:
@@ -578,7 +611,10 @@ class DockerSSHBox(Sandbox):
containers = self.docker_client.containers.list(all=True)
for container in containers:
try:
if container.name.startswith(self.container_name_prefix):
if container.name.startswith(self.container_name):
# only remove the container we created
# otherwise all other containers with the same prefix will be removed
# which will mess up with parallel evaluation
container.remove(force=True)
except docker.errors.NotFound:
pass