Compare commits

...

3 Commits

Author SHA1 Message Date
openhands
2959abf4ba Add missing imports to ActionExecutionClient 2024-12-25 15:54:31 +00:00
openhands
0125a5415f Move all run_action logic to ActionExecutionClient 2024-12-25 15:52:08 +00:00
openhands
65de07299f Refactor runtime action execution
- Create ActionExecutionClient base class for shared HTTP server interaction logic
- Update EventStreamRuntime and RemoteRuntime to inherit from ActionExecutionClient
- Remove duplicate code and clean up imports
- Update ModalRuntime and RunloopRuntime to use super().__init__()
2024-12-25 15:47:02 +00:00
5 changed files with 223 additions and 157 deletions

View File

@@ -0,0 +1,193 @@
"""Base class for runtimes that interact with the action execution server."""
import threading
from typing import Any
import requests
import tenacity
from openhands.core.config import AppConfig
from openhands.core.exceptions import (
AgentRuntimeDisconnectedError,
AgentRuntimeError,
AgentRuntimeNotFoundError,
AgentRuntimeNotReadyError,
AgentRuntimeTimeoutError,
)
from openhands.events import EventStream
from openhands.events.action import Action, ActionConfirmationStatus, FileEditAction
from openhands.events.observation import (
ErrorObservation,
NullObservation,
Observation,
UserRejectObservation,
)
from openhands.events.serialization import event_to_dict, observation_from_dict
from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
from openhands.runtime.base import Runtime
from openhands.runtime.plugins import PluginRequirement
from openhands.runtime.utils.request import send_request
from openhands.utils.tenacity_stop import stop_if_should_exit
class ActionExecutionClient(Runtime):
"""Base class for runtimes that interact with the action execution server.
This class contains shared logic between EventStreamRuntime and RemoteRuntime
for interacting with the HTTP server defined in action_execution_server.py.
"""
def __init__(
self,
config: AppConfig,
event_stream: EventStream,
sid: str = "default",
plugins: list[PluginRequirement] | None = None,
env_vars: dict[str, str] | None = None,
status_callback: Any | None = None,
attach_to_existing: bool = False,
headless_mode: bool = True,
):
super().__init__(
config,
event_stream,
sid,
plugins,
env_vars,
status_callback,
attach_to_existing,
headless_mode,
)
self.session = requests.Session()
self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time
self._runtime_initialized: bool = False
self.api_url: str | None = None
def _send_request(
self,
method: str,
url: str,
is_retry: bool = True,
**kwargs,
) -> requests.Response:
"""Send a request to the action execution server.
Args:
method: HTTP method (GET, POST, etc.)
url: URL to send the request to
is_retry: Whether to retry the request on failure
**kwargs: Additional arguments to pass to requests.request()
Returns:
Response from the server
Raises:
AgentRuntimeError: If the request fails
"""
if not self._runtime_initialized and not url.endswith("/alive"):
raise AgentRuntimeNotReadyError("Runtime client is not ready.")
if is_retry:
retry_decorator = tenacity.retry(
stop=tenacity.stop_after_delay(120) | stop_if_should_exit(),
retry=tenacity.retry_if_exception_type(
(ConnectionError, requests.exceptions.ConnectionError)
),
reraise=True,
wait=tenacity.wait_fixed(2),
)
return retry_decorator(send_request)(self.session, method, url, **kwargs)
else:
return send_request(self.session, method, url, **kwargs)
def run_action(self, action: Action) -> Observation:
"""Run an action by sending it to the action execution server.
Args:
action: Action to execute
Returns:
Observation from executing the action
"""
if isinstance(action, FileEditAction):
return self.edit(action)
# set timeout to default if not set
if action.timeout is None:
action.timeout = self.config.sandbox.timeout
with self.action_semaphore:
if not action.runnable:
return NullObservation("")
if (
hasattr(action, 'confirmation_state')
and action.confirmation_state
== ActionConfirmationStatus.AWAITING_CONFIRMATION
):
return NullObservation('')
action_type = action.action # type: ignore[attr-defined]
if action_type not in ACTION_TYPE_TO_CLASS:
raise ValueError(f'Action {action_type} does not exist.')
if not hasattr(self, action_type):
return ErrorObservation(
f'Action {action_type} is not supported in the current runtime.',
error_id='AGENT_ERROR$BAD_ACTION',
)
if (
getattr(action, 'confirmation_state', None)
== ActionConfirmationStatus.REJECTED
):
return UserRejectObservation(
'Action has been rejected by the user! Waiting for further user input.'
)
assert action.timeout is not None
try:
request_body = {'action': event_to_dict(action)}
self.log('debug', f'Request body: {request_body}')
url = f"{self.runtime_url}/execute_action" if hasattr(self, "runtime_url") else f"{self.api_url}/execute_action"
with self._send_request(
'POST',
url,
json=request_body,
# wait a few more seconds to get the timeout error from client side
timeout=action.timeout + 5,
) as response:
output = response.json()
obs = observation_from_dict(output)
obs._cause = action.id # type: ignore[attr-defined]
return obs
except requests.Timeout:
raise AgentRuntimeTimeoutError(
f'Runtime failed to return execute_action before the requested timeout of {action.timeout}s'
)
def _wait_until_alive(self):
"""Wait until the action execution server is alive and ready.
Raises:
AgentRuntimeNotReadyError: If the server does not become ready in time
AgentRuntimeDisconnectedError: If the connection is lost
"""
retry_decorator = tenacity.retry(
stop=tenacity.stop_after_delay(120) | stop_if_should_exit(),
retry=tenacity.retry_if_exception_type(
(ConnectionError, requests.exceptions.ConnectionError)
),
reraise=True,
wait=tenacity.wait_fixed(2),
)
try:
with retry_decorator(send_request)(
self.session,
"GET",
f"{self.api_url}/alive",
timeout=5,
):
pass
except requests.exceptions.ConnectionError as e:
raise AgentRuntimeDisconnectedError(
f"Lost connection to runtime client: {e}"
) from e

View File

@@ -1,7 +1,6 @@
import atexit
import os
import tempfile
import threading
from functools import lru_cache
from pathlib import Path
from typing import Callable
@@ -23,6 +22,7 @@ from openhands.core.logger import DEBUG
from openhands.core.logger import openhands_logger as logger
from openhands.events import EventStream
from openhands.events.action import (
Action,
ActionConfirmationStatus,
BrowseInteractiveAction,
BrowseURLAction,
@@ -32,8 +32,8 @@ from openhands.events.action import (
FileWriteAction,
IPythonRunCellAction,
)
from openhands.events.action.action import Action
from openhands.events.observation import (
CmdOutputObservation,
ErrorObservation,
NullObservation,
Observation,
@@ -41,8 +41,8 @@ from openhands.events.observation import (
)
from openhands.events.serialization import event_to_dict, observation_from_dict
from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
from openhands.runtime.base import Runtime
from openhands.runtime.builder import DockerRuntimeBuilder
from openhands.runtime.impl.action_execution_client import ActionExecutionClient
from openhands.runtime.impl.eventstream.containers import remove_all_containers
from openhands.runtime.plugins import PluginRequirement
from openhands.runtime.utils import find_available_tcp_port
@@ -62,7 +62,7 @@ def remove_all_runtime_containers():
_atexit_registered = False
class EventStreamRuntime(Runtime):
class EventStreamRuntime(ActionExecutionClient):
"""This runtime will subscribe the event stream.
When receive an event, it will send the event to runtime-client which run inside the docker environment.
@@ -74,30 +74,6 @@ class EventStreamRuntime(Runtime):
env_vars (dict[str, str] | None, optional): Environment variables to set. Defaults to None.
"""
# Need to provide this method to allow inheritors to init the Runtime
# without initting the EventStreamRuntime.
def init_base_runtime(
self,
config: AppConfig,
event_stream: EventStream,
sid: str = 'default',
plugins: list[PluginRequirement] | None = None,
env_vars: dict[str, str] | None = None,
status_callback: Callable | None = None,
attach_to_existing: bool = False,
headless_mode: bool = True,
):
super().__init__(
config,
event_stream,
sid,
plugins,
env_vars,
status_callback,
attach_to_existing,
headless_mode,
)
def __init__(
self,
config: AppConfig,
@@ -114,28 +90,7 @@ class EventStreamRuntime(Runtime):
_atexit_registered = True
atexit.register(remove_all_runtime_containers)
self.config = config
self._host_port = 30000 # initial dummy value
self._container_port = 30001 # initial dummy value
self._vscode_url: str | None = None # initial dummy value
self._runtime_initialized: bool = False
self.api_url = f'{self.config.sandbox.local_runtime_url}:{self._container_port}'
self.session = requests.Session()
self.status_callback = status_callback
self.docker_client: docker.DockerClient = self._init_docker_client()
self.base_container_image = self.config.sandbox.base_container_image
self.runtime_container_image = self.config.sandbox.runtime_container_image
self.container_name = CONTAINER_NAME_PREFIX + sid
self.container = None
self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time
self.runtime_builder = DockerRuntimeBuilder(self.docker_client)
# Buffer for container logs
self.log_streamer: LogStreamer | None = None
self.init_base_runtime(
super().__init__(
config,
event_stream,
sid,
@@ -146,6 +101,22 @@ class EventStreamRuntime(Runtime):
headless_mode,
)
self._host_port = 30000 # initial dummy value
self._container_port = 30001 # initial dummy value
self._vscode_url: str | None = None # initial dummy value
self.api_url = f'{self.config.sandbox.local_runtime_url}:{self._container_port}'
self.docker_client: docker.DockerClient = self._init_docker_client()
self.base_container_image = self.config.sandbox.base_container_image
self.runtime_container_image = self.config.sandbox.runtime_container_image
self.container_name = CONTAINER_NAME_PREFIX + sid
self.container = None
self.runtime_builder = DockerRuntimeBuilder(self.docker_client)
# Buffer for container logs
self.log_streamer: LogStreamer | None = None
# Log runtime_extra_deps after base class initialization so self.sid is available
if self.config.sandbox.runtime_extra_deps:
self.log(
@@ -407,59 +378,7 @@ class EventStreamRuntime(Runtime):
)
remove_all_containers(close_prefix)
def run_action(self, action: Action) -> Observation:
if isinstance(action, FileEditAction):
return self.edit(action)
# set timeout to default if not set
if action.timeout is None:
action.timeout = self.config.sandbox.timeout
with self.action_semaphore:
if not action.runnable:
return NullObservation('')
if (
hasattr(action, 'confirmation_state')
and action.confirmation_state
== ActionConfirmationStatus.AWAITING_CONFIRMATION
):
return NullObservation('')
action_type = action.action # type: ignore[attr-defined]
if action_type not in ACTION_TYPE_TO_CLASS:
raise ValueError(f'Action {action_type} does not exist.')
if not hasattr(self, action_type):
return ErrorObservation(
f'Action {action_type} is not supported in the current runtime.',
error_id='AGENT_ERROR$BAD_ACTION',
)
if (
getattr(action, 'confirmation_state', None)
== ActionConfirmationStatus.REJECTED
):
return UserRejectObservation(
'Action has been rejected by the user! Waiting for further user input.'
)
assert action.timeout is not None
try:
with send_request(
self.session,
'POST',
f'{self.api_url}/execute_action',
json={'action': event_to_dict(action)},
# wait a few more seconds to get the timeout error from client side
timeout=action.timeout + 5,
) as response:
output = response.json()
obs = observation_from_dict(output)
obs._cause = action.id # type: ignore[attr-defined]
except requests.Timeout:
raise AgentRuntimeTimeoutError(
f'Runtime failed to return execute_action before the requested timeout of {action.timeout}s'
)
return obs
def run(self, action: CmdRunAction) -> Observation:
return self.run_action(action)

View File

@@ -131,7 +131,7 @@ class ModalRuntime(EventStreamRuntime):
f'Installing extra user-provided dependencies in the runtime image: {self.config.sandbox.runtime_extra_deps}',
)
self.init_base_runtime(
super().__init__(
config,
event_stream,
sid,

View File

@@ -1,6 +1,5 @@
import os
import tempfile
import threading
from pathlib import Path
from typing import Callable, Optional
from urllib.parse import urlparse
@@ -20,6 +19,7 @@ from openhands.core.exceptions import (
)
from openhands.events import EventStream
from openhands.events.action import (
Action,
BrowseInteractiveAction,
BrowseURLAction,
CmdRunAction,
@@ -28,28 +28,23 @@ from openhands.events.action import (
FileWriteAction,
IPythonRunCellAction,
)
from openhands.events.action.action import Action
from openhands.events.observation import (
ErrorObservation,
NullObservation,
Observation,
)
from openhands.events.serialization import event_to_dict, observation_from_dict
from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
from openhands.runtime.base import Runtime
from openhands.runtime.builder.remote import RemoteRuntimeBuilder
from openhands.runtime.impl.action_execution_client import ActionExecutionClient
from openhands.runtime.plugins import PluginRequirement
from openhands.runtime.utils.command import get_remote_startup_command
from openhands.runtime.utils.request import (
RequestHTTPError,
send_request,
)
from openhands.runtime.utils.request import RequestHTTPError, send_request
from openhands.runtime.utils.runtime_build import build_runtime_image
from openhands.utils.async_utils import call_sync_from_async
from openhands.utils.tenacity_stop import stop_if_should_exit
class RemoteRuntime(Runtime):
class RemoteRuntime(ActionExecutionClient):
"""This runtime will connect to a remote oh-runtime-client."""
port: int = 60000 # default port for the remote runtime client
@@ -65,10 +60,6 @@ class RemoteRuntime(Runtime):
attach_to_existing: bool = False,
headless_mode: bool = True,
):
# We need to set session and action_semaphore before the __init__ below, or we get odd errors
self.session = requests.Session()
self.action_semaphore = threading.Semaphore(1)
super().__init__(
config,
event_stream,
@@ -79,6 +70,7 @@ class RemoteRuntime(Runtime):
attach_to_existing,
headless_mode,
)
if self.config.sandbox.api_key is None:
raise ValueError(
'API key is required to use the remote runtime. '
@@ -97,7 +89,6 @@ class RemoteRuntime(Runtime):
)
self.runtime_id: str | None = None
self.runtime_url: str | None = None
self._runtime_initialized: bool = False
self._vscode_url: str | None = None # initial dummy value
async def connect(self):
@@ -407,45 +398,6 @@ class RemoteRuntime(Runtime):
finally:
self.session.close()
def run_action(self, action: Action, is_retry: bool = False) -> Observation:
if action.timeout is None:
action.timeout = self.config.sandbox.timeout
if isinstance(action, FileEditAction):
return self.edit(action)
with self.action_semaphore:
if not action.runnable:
return NullObservation('')
action_type = action.action # type: ignore[attr-defined]
if action_type not in ACTION_TYPE_TO_CLASS:
raise ValueError(f'Action {action_type} does not exist.')
if not hasattr(self, action_type):
return ErrorObservation(
f'[Runtime (ID={self.runtime_id})] Action {action_type} is not supported in the current runtime.',
error_id='AGENT_ERROR$BAD_ACTION',
)
assert action.timeout is not None
try:
request_body = {'action': event_to_dict(action)}
self.log('debug', f'Request body: {request_body}')
with self._send_request(
'POST',
f'{self.runtime_url}/execute_action',
is_retry=False,
json=request_body,
# wait a few more seconds to get the timeout error from client side
timeout=action.timeout + 5,
) as response:
output = response.json()
obs = observation_from_dict(output)
obs._cause = action.id # type: ignore[attr-defined]
except requests.Timeout:
raise AgentRuntimeTimeoutError(
f'Runtime failed to return execute_action before the requested timeout of {action.timeout}s'
)
return obs
def _send_request(self, method, url, is_retry=False, **kwargs):
is_runtime_request = self.runtime_url and self.runtime_url in url
try:
@@ -490,6 +442,8 @@ class RemoteRuntime(Runtime):
def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
return self.run_action(action)
def copy_to(
self, host_src: str, sandbox_dest: str, recursive: bool = False
) -> None:

View File

@@ -101,7 +101,7 @@ class RunloopRuntime(EventStreamRuntime):
self.session = requests.Session()
self.container_name = CONTAINER_NAME_PREFIX + sid
self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time
self.init_base_runtime(
super().__init__(
config,
event_stream,
sid,