mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
3 Commits
github-tok
...
refactor-r
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2959abf4ba | ||
|
|
0125a5415f | ||
|
|
65de07299f |
193
openhands/runtime/impl/action_execution_client.py
Normal file
193
openhands/runtime/impl/action_execution_client.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""Base class for runtimes that interact with the action execution server."""
|
||||
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
import tenacity
|
||||
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.core.exceptions import (
|
||||
AgentRuntimeDisconnectedError,
|
||||
AgentRuntimeError,
|
||||
AgentRuntimeNotFoundError,
|
||||
AgentRuntimeNotReadyError,
|
||||
AgentRuntimeTimeoutError,
|
||||
)
|
||||
from openhands.events import EventStream
|
||||
from openhands.events.action import Action, ActionConfirmationStatus, FileEditAction
|
||||
from openhands.events.observation import (
|
||||
ErrorObservation,
|
||||
NullObservation,
|
||||
Observation,
|
||||
UserRejectObservation,
|
||||
)
|
||||
from openhands.events.serialization import event_to_dict, observation_from_dict
|
||||
from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.runtime.plugins import PluginRequirement
|
||||
from openhands.runtime.utils.request import send_request
|
||||
from openhands.utils.tenacity_stop import stop_if_should_exit
|
||||
|
||||
|
||||
class ActionExecutionClient(Runtime):
    """Base class for runtimes that interact with the action execution server.

    This class contains shared logic between EventStreamRuntime and RemoteRuntime
    for interacting with the HTTP server defined in action_execution_server.py.
    """

    def __init__(
        self,
        config: AppConfig,
        event_stream: EventStream,
        sid: str = 'default',
        plugins: list[PluginRequirement] | None = None,
        env_vars: dict[str, str] | None = None,
        status_callback: Any | None = None,
        attach_to_existing: bool = False,
        headless_mode: bool = True,
    ):
        """Create the HTTP client state and initialize the base ``Runtime``.

        All arguments are forwarded positionally, unchanged, to
        ``Runtime.__init__``.
        """
        # NOTE(review): the client-side state is created *before* delegating to
        # Runtime.__init__. RemoteRuntime previously documented "odd errors"
        # when session/action_semaphore were created after the base
        # initializer; presumably something started by the base init can
        # already reach them -- confirm against Runtime.__init__.
        self.session = requests.Session()
        # Ensure one action at a time
        self.action_semaphore = threading.Semaphore(1)
        # Gate checked by _send_request; this class never sets it to True.
        self._runtime_initialized: bool = False
        # Base URL of the action execution server; this class leaves it unset.
        self.api_url: str | None = None

        super().__init__(
            config,
            event_stream,
            sid,
            plugins,
            env_vars,
            status_callback,
            attach_to_existing,
            headless_mode,
        )

    @staticmethod
    def _connection_retry():
        """Build the tenacity decorator shared by the retried server calls.

        The decorated call is retried every 2 seconds while it raises a
        connection error, until 120 seconds have elapsed or the
        ``stop_if_should_exit()`` condition fires; the last error is then
        re-raised unchanged (``reraise=True``).
        """
        return tenacity.retry(
            stop=tenacity.stop_after_delay(120) | stop_if_should_exit(),
            retry=tenacity.retry_if_exception_type(
                (ConnectionError, requests.exceptions.ConnectionError)
            ),
            reraise=True,
            wait=tenacity.wait_fixed(2),
        )

    def _send_request(
        self,
        method: str,
        url: str,
        is_retry: bool = True,
        **kwargs,
    ) -> requests.Response:
        """Send a request to the action execution server.

        Args:
            method: HTTP method (GET, POST, etc.)
            url: URL to send the request to
            is_retry: Whether to retry the request on connection errors
            **kwargs: Additional arguments forwarded to ``send_request``

        Returns:
            Response from the server

        Raises:
            AgentRuntimeNotReadyError: If the runtime is not yet initialized
                and the URL is not the ``/alive`` endpoint.
        """
        # Only the liveness probe may be sent before initialization completes.
        if not self._runtime_initialized and not url.endswith('/alive'):
            raise AgentRuntimeNotReadyError('Runtime client is not ready.')

        if not is_retry:
            return send_request(self.session, method, url, **kwargs)

        retrying_send = self._connection_retry()(send_request)
        return retrying_send(self.session, method, url, **kwargs)

    def run_action(self, action: Action) -> Observation:
        """Run an action by sending it to the action execution server.

        Args:
            action: Action to execute

        Returns:
            Observation from executing the action. A ``NullObservation`` is
            returned for non-runnable actions and actions awaiting
            confirmation, an ``ErrorObservation`` when this runtime has no
            attribute named after the action type, and a
            ``UserRejectObservation`` for rejected actions.

        Raises:
            ValueError: If the action type is not in ``ACTION_TYPE_TO_CLASS``.
            AgentRuntimeTimeoutError: If the request times out.
        """
        if isinstance(action, FileEditAction):
            return self.edit(action)

        # set timeout to default if not set
        if action.timeout is None:
            action.timeout = self.config.sandbox.timeout

        with self.action_semaphore:
            if not action.runnable:
                return NullObservation('')
            if (
                hasattr(action, 'confirmation_state')
                and action.confirmation_state
                == ActionConfirmationStatus.AWAITING_CONFIRMATION
            ):
                return NullObservation('')
            action_type = action.action  # type: ignore[attr-defined]
            if action_type not in ACTION_TYPE_TO_CLASS:
                raise ValueError(f'Action {action_type} does not exist.')
            if not hasattr(self, action_type):
                return ErrorObservation(
                    f'Action {action_type} is not supported in the current runtime.',
                    error_id='AGENT_ERROR$BAD_ACTION',
                )
            if (
                getattr(action, 'confirmation_state', None)
                == ActionConfirmationStatus.REJECTED
            ):
                return UserRejectObservation(
                    'Action has been rejected by the user! Waiting for further user input.'
                )

            assert action.timeout is not None

            try:
                request_body = {'action': event_to_dict(action)}
                self.log('debug', f'Request body: {request_body}')
                # Subclasses that define runtime_url take precedence over api_url.
                if hasattr(self, 'runtime_url'):
                    url = f'{self.runtime_url}/execute_action'
                else:
                    url = f'{self.api_url}/execute_action'
                with self._send_request(
                    'POST',
                    url,
                    # Executing an action is not idempotent: a transparent
                    # retry after a dropped connection could run it twice.
                    is_retry=False,
                    json=request_body,
                    # wait a few more seconds to get the timeout error from client side
                    timeout=action.timeout + 5,
                ) as response:
                    output = response.json()
                    obs = observation_from_dict(output)
                    obs._cause = action.id  # type: ignore[attr-defined]
                    return obs
            except requests.Timeout as e:
                raise AgentRuntimeTimeoutError(
                    f'Runtime failed to return execute_action before the requested timeout of {action.timeout}s'
                ) from e

    def _wait_until_alive(self):
        """Wait until the action execution server answers its ``/alive`` endpoint.

        The probe is sent with ``send_request`` directly, so the
        ``_runtime_initialized`` gate in ``_send_request`` does not apply.

        Raises:
            AgentRuntimeDisconnectedError: If the probe still fails with a
                ``requests`` connection error once retrying stops.
        """
        probe = self._connection_retry()(send_request)

        try:
            with probe(
                self.session,
                'GET',
                f'{self.api_url}/alive',
                timeout=5,
            ):
                pass
        except requests.exceptions.ConnectionError as e:
            raise AgentRuntimeDisconnectedError(
                f'Lost connection to runtime client: {e}'
            ) from e
|
||||
@@ -1,7 +1,6 @@
|
||||
import atexit
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
@@ -23,6 +22,7 @@ from openhands.core.logger import DEBUG
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events import EventStream
|
||||
from openhands.events.action import (
|
||||
Action,
|
||||
ActionConfirmationStatus,
|
||||
BrowseInteractiveAction,
|
||||
BrowseURLAction,
|
||||
@@ -32,8 +32,8 @@ from openhands.events.action import (
|
||||
FileWriteAction,
|
||||
IPythonRunCellAction,
|
||||
)
|
||||
from openhands.events.action.action import Action
|
||||
from openhands.events.observation import (
|
||||
CmdOutputObservation,
|
||||
ErrorObservation,
|
||||
NullObservation,
|
||||
Observation,
|
||||
@@ -41,8 +41,8 @@ from openhands.events.observation import (
|
||||
)
|
||||
from openhands.events.serialization import event_to_dict, observation_from_dict
|
||||
from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.runtime.builder import DockerRuntimeBuilder
|
||||
from openhands.runtime.impl.action_execution_client import ActionExecutionClient
|
||||
from openhands.runtime.impl.eventstream.containers import remove_all_containers
|
||||
from openhands.runtime.plugins import PluginRequirement
|
||||
from openhands.runtime.utils import find_available_tcp_port
|
||||
@@ -62,7 +62,7 @@ def remove_all_runtime_containers():
|
||||
_atexit_registered = False
|
||||
|
||||
|
||||
class EventStreamRuntime(Runtime):
|
||||
class EventStreamRuntime(ActionExecutionClient):
|
||||
"""This runtime will subscribe the event stream.
|
||||
When receive an event, it will send the event to runtime-client which run inside the docker environment.
|
||||
|
||||
@@ -74,30 +74,6 @@ class EventStreamRuntime(Runtime):
|
||||
env_vars (dict[str, str] | None, optional): Environment variables to set. Defaults to None.
|
||||
"""
|
||||
|
||||
# Need to provide this method to allow inheritors to init the Runtime
|
||||
# without initting the EventStreamRuntime.
|
||||
def init_base_runtime(
    self,
    config: AppConfig,
    event_stream: EventStream,
    sid: str = 'default',
    plugins: list[PluginRequirement] | None = None,
    env_vars: dict[str, str] | None = None,
    status_callback: Callable | None = None,
    attach_to_existing: bool = False,
    headless_mode: bool = True,
):
    """Run only the parent initializer with the given arguments.

    Lets inheritors initialize the Runtime without running
    EventStreamRuntime's own ``__init__``. Every argument is passed through
    positionally and unchanged.
    """
    base_args = (
        config,
        event_stream,
        sid,
        plugins,
        env_vars,
        status_callback,
        attach_to_existing,
        headless_mode,
    )
    super().__init__(*base_args)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: AppConfig,
|
||||
@@ -114,28 +90,7 @@ class EventStreamRuntime(Runtime):
|
||||
_atexit_registered = True
|
||||
atexit.register(remove_all_runtime_containers)
|
||||
|
||||
self.config = config
|
||||
self._host_port = 30000 # initial dummy value
|
||||
self._container_port = 30001 # initial dummy value
|
||||
self._vscode_url: str | None = None # initial dummy value
|
||||
self._runtime_initialized: bool = False
|
||||
self.api_url = f'{self.config.sandbox.local_runtime_url}:{self._container_port}'
|
||||
self.session = requests.Session()
|
||||
self.status_callback = status_callback
|
||||
|
||||
self.docker_client: docker.DockerClient = self._init_docker_client()
|
||||
self.base_container_image = self.config.sandbox.base_container_image
|
||||
self.runtime_container_image = self.config.sandbox.runtime_container_image
|
||||
self.container_name = CONTAINER_NAME_PREFIX + sid
|
||||
self.container = None
|
||||
self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time
|
||||
|
||||
self.runtime_builder = DockerRuntimeBuilder(self.docker_client)
|
||||
|
||||
# Buffer for container logs
|
||||
self.log_streamer: LogStreamer | None = None
|
||||
|
||||
self.init_base_runtime(
|
||||
super().__init__(
|
||||
config,
|
||||
event_stream,
|
||||
sid,
|
||||
@@ -146,6 +101,22 @@ class EventStreamRuntime(Runtime):
|
||||
headless_mode,
|
||||
)
|
||||
|
||||
self._host_port = 30000 # initial dummy value
|
||||
self._container_port = 30001 # initial dummy value
|
||||
self._vscode_url: str | None = None # initial dummy value
|
||||
self.api_url = f'{self.config.sandbox.local_runtime_url}:{self._container_port}'
|
||||
|
||||
self.docker_client: docker.DockerClient = self._init_docker_client()
|
||||
self.base_container_image = self.config.sandbox.base_container_image
|
||||
self.runtime_container_image = self.config.sandbox.runtime_container_image
|
||||
self.container_name = CONTAINER_NAME_PREFIX + sid
|
||||
self.container = None
|
||||
|
||||
self.runtime_builder = DockerRuntimeBuilder(self.docker_client)
|
||||
|
||||
# Buffer for container logs
|
||||
self.log_streamer: LogStreamer | None = None
|
||||
|
||||
# Log runtime_extra_deps after base class initialization so self.sid is available
|
||||
if self.config.sandbox.runtime_extra_deps:
|
||||
self.log(
|
||||
@@ -407,59 +378,7 @@ class EventStreamRuntime(Runtime):
|
||||
)
|
||||
remove_all_containers(close_prefix)
|
||||
|
||||
def run_action(self, action: Action) -> Observation:
    """Execute ``action`` on the runtime server and return its observation.

    File edits are delegated to ``self.edit``. Non-runnable actions and
    actions awaiting confirmation yield a ``NullObservation``; rejected
    actions yield a ``UserRejectObservation``.

    Raises:
        ValueError: If the action type is not in ``ACTION_TYPE_TO_CLASS``.
        AgentRuntimeTimeoutError: If the request times out.
    """
    if isinstance(action, FileEditAction):
        return self.edit(action)

    # Fall back to the sandbox-wide timeout when the action carries none.
    if action.timeout is None:
        action.timeout = self.config.sandbox.timeout

    with self.action_semaphore:
        if not action.runnable:
            return NullObservation('')

        confirmation = getattr(action, 'confirmation_state', None)
        if confirmation == ActionConfirmationStatus.AWAITING_CONFIRMATION:
            return NullObservation('')

        kind = action.action  # type: ignore[attr-defined]
        if kind not in ACTION_TYPE_TO_CLASS:
            raise ValueError(f'Action {kind} does not exist.')
        if not hasattr(self, kind):
            return ErrorObservation(
                f'Action {kind} is not supported in the current runtime.',
                error_id='AGENT_ERROR$BAD_ACTION',
            )

        if confirmation == ActionConfirmationStatus.REJECTED:
            return UserRejectObservation(
                'Action has been rejected by the user! Waiting for further user input.'
            )

        assert action.timeout is not None

        try:
            with send_request(
                self.session,
                'POST',
                f'{self.api_url}/execute_action',
                json={'action': event_to_dict(action)},
                # wait a few more seconds to get the timeout error from client side
                timeout=action.timeout + 5,
            ) as response:
                observation = observation_from_dict(response.json())
                observation._cause = action.id  # type: ignore[attr-defined]
                return observation
        except requests.Timeout:
            raise AgentRuntimeTimeoutError(
                f'Runtime failed to return execute_action before the requested timeout of {action.timeout}s'
            )
|
||||
|
||||
def run(self, action: CmdRunAction) -> Observation:
    """Execute a command action by delegating to ``run_action``."""
    observation = self.run_action(action)
    return observation
|
||||
|
||||
@@ -131,7 +131,7 @@ class ModalRuntime(EventStreamRuntime):
|
||||
f'Installing extra user-provided dependencies in the runtime image: {self.config.sandbox.runtime_extra_deps}',
|
||||
)
|
||||
|
||||
self.init_base_runtime(
|
||||
super().__init__(
|
||||
config,
|
||||
event_stream,
|
||||
sid,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
from urllib.parse import urlparse
|
||||
@@ -20,6 +19,7 @@ from openhands.core.exceptions import (
|
||||
)
|
||||
from openhands.events import EventStream
|
||||
from openhands.events.action import (
|
||||
Action,
|
||||
BrowseInteractiveAction,
|
||||
BrowseURLAction,
|
||||
CmdRunAction,
|
||||
@@ -28,28 +28,23 @@ from openhands.events.action import (
|
||||
FileWriteAction,
|
||||
IPythonRunCellAction,
|
||||
)
|
||||
from openhands.events.action.action import Action
|
||||
from openhands.events.observation import (
|
||||
ErrorObservation,
|
||||
NullObservation,
|
||||
Observation,
|
||||
)
|
||||
from openhands.events.serialization import event_to_dict, observation_from_dict
|
||||
from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.runtime.builder.remote import RemoteRuntimeBuilder
|
||||
from openhands.runtime.impl.action_execution_client import ActionExecutionClient
|
||||
from openhands.runtime.plugins import PluginRequirement
|
||||
from openhands.runtime.utils.command import get_remote_startup_command
|
||||
from openhands.runtime.utils.request import (
|
||||
RequestHTTPError,
|
||||
send_request,
|
||||
)
|
||||
from openhands.runtime.utils.request import RequestHTTPError, send_request
|
||||
from openhands.runtime.utils.runtime_build import build_runtime_image
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.tenacity_stop import stop_if_should_exit
|
||||
|
||||
|
||||
class RemoteRuntime(Runtime):
|
||||
class RemoteRuntime(ActionExecutionClient):
|
||||
"""This runtime will connect to a remote oh-runtime-client."""
|
||||
|
||||
port: int = 60000 # default port for the remote runtime client
|
||||
@@ -65,10 +60,6 @@ class RemoteRuntime(Runtime):
|
||||
attach_to_existing: bool = False,
|
||||
headless_mode: bool = True,
|
||||
):
|
||||
# We need to set session and action_semaphore before the __init__ below, or we get odd errors
|
||||
self.session = requests.Session()
|
||||
self.action_semaphore = threading.Semaphore(1)
|
||||
|
||||
super().__init__(
|
||||
config,
|
||||
event_stream,
|
||||
@@ -79,6 +70,7 @@ class RemoteRuntime(Runtime):
|
||||
attach_to_existing,
|
||||
headless_mode,
|
||||
)
|
||||
|
||||
if self.config.sandbox.api_key is None:
|
||||
raise ValueError(
|
||||
'API key is required to use the remote runtime. '
|
||||
@@ -97,7 +89,6 @@ class RemoteRuntime(Runtime):
|
||||
)
|
||||
self.runtime_id: str | None = None
|
||||
self.runtime_url: str | None = None
|
||||
self._runtime_initialized: bool = False
|
||||
self._vscode_url: str | None = None # initial dummy value
|
||||
|
||||
async def connect(self):
|
||||
@@ -407,45 +398,6 @@ class RemoteRuntime(Runtime):
|
||||
finally:
|
||||
self.session.close()
|
||||
|
||||
def run_action(self, action: Action, is_retry: bool = False) -> Observation:
    """Execute ``action`` on the remote runtime and return its observation.

    File edits are delegated to ``self.edit``; non-runnable actions yield a
    ``NullObservation``. The ``is_retry`` parameter is accepted but not read:
    the request is always sent with ``is_retry=False``.

    Raises:
        ValueError: If the action type is not in ``ACTION_TYPE_TO_CLASS``.
        AgentRuntimeTimeoutError: If the request times out.
    """
    # Fall back to the sandbox-wide timeout when the action carries none.
    if action.timeout is None:
        action.timeout = self.config.sandbox.timeout

    if isinstance(action, FileEditAction):
        return self.edit(action)

    with self.action_semaphore:
        if not action.runnable:
            return NullObservation('')

        kind = action.action  # type: ignore[attr-defined]
        if kind not in ACTION_TYPE_TO_CLASS:
            raise ValueError(f'Action {kind} does not exist.')
        if not hasattr(self, kind):
            return ErrorObservation(
                f'[Runtime (ID={self.runtime_id})] Action {kind} is not supported in the current runtime.',
                error_id='AGENT_ERROR$BAD_ACTION',
            )

        assert action.timeout is not None

        try:
            payload = {'action': event_to_dict(action)}
            self.log('debug', f'Request body: {payload}')
            with self._send_request(
                'POST',
                f'{self.runtime_url}/execute_action',
                is_retry=False,
                json=payload,
                # wait a few more seconds to get the timeout error from client side
                timeout=action.timeout + 5,
            ) as response:
                observation = observation_from_dict(response.json())
                observation._cause = action.id  # type: ignore[attr-defined]
        except requests.Timeout:
            raise AgentRuntimeTimeoutError(
                f'Runtime failed to return execute_action before the requested timeout of {action.timeout}s'
            )
        return observation
|
||||
|
||||
def _send_request(self, method, url, is_retry=False, **kwargs):
|
||||
is_runtime_request = self.runtime_url and self.runtime_url in url
|
||||
try:
|
||||
@@ -490,6 +442,8 @@ class RemoteRuntime(Runtime):
|
||||
def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
    """Execute an interactive browse action by delegating to ``run_action``."""
    observation = self.run_action(action)
    return observation
|
||||
|
||||
|
||||
|
||||
def copy_to(
|
||||
self, host_src: str, sandbox_dest: str, recursive: bool = False
|
||||
) -> None:
|
||||
|
||||
@@ -101,7 +101,7 @@ class RunloopRuntime(EventStreamRuntime):
|
||||
self.session = requests.Session()
|
||||
self.container_name = CONTAINER_NAME_PREFIX + sid
|
||||
self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time
|
||||
self.init_base_runtime(
|
||||
super().__init__(
|
||||
config,
|
||||
event_stream,
|
||||
sid,
|
||||
|
||||
Reference in New Issue
Block a user