mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
2 Commits
auto/execu
...
openhands/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a4c8a7d6fe | ||
|
|
52e329f5cd |
4
.gitignore
vendored
4
.gitignore
vendored
@@ -258,4 +258,8 @@ containers/runtime/code
|
||||
# test results
|
||||
test-results
|
||||
.sessions
|
||||
|
||||
# ignore agent-sdk embedded repo if present
|
||||
agent-sdk/
|
||||
|
||||
.eval_sessions
|
||||
|
||||
6
openhands/acp/__init__.py
Normal file
6
openhands/acp/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .server import ACPAgentServer, run_stdio_server
|
||||
|
||||
__all__ = [
|
||||
'ACPAgentServer',
|
||||
'run_stdio_server',
|
||||
]
|
||||
11
openhands/acp/__main__.py
Normal file
11
openhands/acp/__main__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
import asyncio
|
||||
|
||||
from .server import run_stdio_server
|
||||
|
||||
|
||||
def main() -> None:
|
||||
asyncio.run(run_stdio_server())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
154
openhands/acp/jsonrpc.py
Normal file
154
openhands/acp/jsonrpc.py
Normal file
@@ -0,0 +1,154 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, AsyncIterator, Awaitable, Callable
|
||||
|
||||
Json = dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Request:
|
||||
id: int
|
||||
method: str
|
||||
params: Any | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Response:
|
||||
id: int
|
||||
result: Any | None = None
|
||||
error: Any | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Notification:
|
||||
method: str
|
||||
params: Any | None
|
||||
|
||||
|
||||
class NDJsonStdio:
|
||||
"""Simple newline-delimited JSON over stdio.
|
||||
|
||||
This intentionally follows the ACP typescript ndJsonStream helper for simplicity.
|
||||
"""
|
||||
|
||||
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
|
||||
self.reader = reader
|
||||
self.writer = writer
|
||||
self._write_lock = asyncio.Lock()
|
||||
|
||||
async def write(self, obj: Any) -> None:
|
||||
data = json.dumps(obj, separators=(',', ':')) + '\n'
|
||||
async with self._write_lock:
|
||||
self.writer.write(data.encode('utf-8'))
|
||||
await self.writer.drain()
|
||||
|
||||
async def read(self) -> AsyncIterator[Any]:
|
||||
while not self.reader.at_eof():
|
||||
line = await self.reader.readline()
|
||||
if not line:
|
||||
break
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
yield json.loads(line)
|
||||
except Exception:
|
||||
# ignore malformed lines
|
||||
continue
|
||||
|
||||
|
||||
class JsonRpcConnection:
|
||||
def __init__(self, stream: NDJsonStdio):
|
||||
self.stream = stream
|
||||
self._id = 0
|
||||
self._pending: dict[int, asyncio.Future[Any]] = {}
|
||||
self._closed = asyncio.Event()
|
||||
self._tasks: set[asyncio.Task[Any]] = set()
|
||||
|
||||
async def send_request(self, method: str, params: Any | None = None) -> Any:
|
||||
self._id += 1
|
||||
req_id = self._id
|
||||
fut: asyncio.Future[Any] = asyncio.get_running_loop().create_future()
|
||||
self._pending[req_id] = fut
|
||||
await self.stream.write(
|
||||
{'jsonrpc': '2.0', 'id': req_id, 'method': method, 'params': params}
|
||||
)
|
||||
return await fut
|
||||
|
||||
async def send_notification(self, method: str, params: Any | None = None) -> None:
|
||||
await self.stream.write({'jsonrpc': '2.0', 'method': method, 'params': params})
|
||||
|
||||
async def send_response(
|
||||
self, id: int, result: Any | None = None, error: Any | None = None
|
||||
) -> None:
|
||||
if error is not None:
|
||||
await self.stream.write({'jsonrpc': '2.0', 'id': id, 'error': error})
|
||||
else:
|
||||
await self.stream.write({'jsonrpc': '2.0', 'id': id, 'result': result})
|
||||
|
||||
def _create_task(self, coro: Awaitable[Any]) -> None:
|
||||
task: asyncio.Task[Any] = asyncio.create_task(coro) # type: ignore[arg-type]
|
||||
self._tasks.add(task)
|
||||
task.add_done_callback(self._tasks.discard) # type: ignore[arg-type]
|
||||
|
||||
async def serve(
|
||||
self,
|
||||
on_request: Callable[[str, Any | None], Awaitable[Any | None]],
|
||||
on_notification: Callable[[str, Any | None], Awaitable[None]] | None = None,
|
||||
) -> None:
|
||||
async for msg in self.stream.read():
|
||||
try:
|
||||
if not isinstance(msg, dict) or msg.get('jsonrpc') != '2.0':
|
||||
continue
|
||||
if 'method' in msg:
|
||||
method = msg['method']
|
||||
params = msg.get('params')
|
||||
if 'id' in msg:
|
||||
req_id = msg['id']
|
||||
|
||||
async def handle_req(
|
||||
method: str = method,
|
||||
params: Any | None = params,
|
||||
req_id: int = req_id,
|
||||
) -> None:
|
||||
try:
|
||||
result = await on_request(method, params)
|
||||
await self.send_response(
|
||||
req_id, result=result if result is not None else {}
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
await self.send_response(
|
||||
req_id,
|
||||
error={'code': -32800, 'message': 'cancelled'},
|
||||
)
|
||||
except Exception as e: # noqa: BLE001
|
||||
await self.send_response(
|
||||
req_id, error={'code': -32603, 'message': str(e)}
|
||||
)
|
||||
|
||||
self._create_task(handle_req())
|
||||
else:
|
||||
if on_notification is not None:
|
||||
self._create_task(on_notification(method, params))
|
||||
elif 'id' in msg:
|
||||
fut = self._pending.pop(int(msg['id']), None)
|
||||
if fut:
|
||||
if 'result' in msg:
|
||||
fut.set_result(msg['result'])
|
||||
else:
|
||||
fut.set_exception(
|
||||
RuntimeError(msg.get('error') or 'unknown error')
|
||||
)
|
||||
except Exception:
|
||||
# ignore
|
||||
continue
|
||||
# Wait a brief moment for any straggling tasks
|
||||
if self._tasks:
|
||||
await asyncio.wait(self._tasks, timeout=1.0)
|
||||
self._closed.set()
|
||||
|
||||
async def wait_closed(self) -> None:
|
||||
await self._closed.wait()
|
||||
166
openhands/acp/server.py
Normal file
166
openhands/acp/server.py
Normal file
@@ -0,0 +1,166 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from openhands.acp.jsonrpc import JsonRpcConnection, NDJsonStdio
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
PROTOCOL_VERSION = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionState:
|
||||
pending_task: asyncio.Task | None = None
|
||||
|
||||
|
||||
class ACPAgentServer:
|
||||
"""Minimal ACP adapter to expose OpenHands as an ACP Agent over stdio NDJSON.
|
||||
|
||||
Implements initialize, session/new, session/prompt and session/cancel,
|
||||
and provides client-facing notifications session/update and requests
|
||||
like session/request_permission in the future. This is a minimal MVP to
|
||||
integrate with Zed ACP client.
|
||||
"""
|
||||
|
||||
def __init__(self, rpc: JsonRpcConnection):
|
||||
self.rpc = rpc
|
||||
self.sessions: dict[str, SessionState] = {}
|
||||
|
||||
async def handle_request(self, method: str, params: Any | None) -> Any | None:
|
||||
if method == 'initialize':
|
||||
return await self._initialize(params)
|
||||
if method == 'session/new':
|
||||
return await self._session_new(params)
|
||||
if method == 'session/prompt':
|
||||
return await self._session_prompt(params)
|
||||
if method == 'session/cancel':
|
||||
# Spec: cancel is a notification, but handle gracefully if sent as request
|
||||
await self._session_cancel(params)
|
||||
return {}
|
||||
if method == 'authenticate':
|
||||
# No-op for now
|
||||
return {}
|
||||
if method == 'session/set_mode':
|
||||
return {}
|
||||
raise RuntimeError(f'Method not implemented: {method}')
|
||||
|
||||
async def handle_notification(self, method: str, params: Any | None) -> None:
|
||||
if method == 'session/cancel':
|
||||
await self._session_cancel(params)
|
||||
|
||||
async def _initialize(self, params: dict[str, Any] | None) -> dict[str, Any]:
|
||||
return {
|
||||
'protocolVersion': PROTOCOL_VERSION,
|
||||
'agentCapabilities': {
|
||||
'loadSession': False,
|
||||
},
|
||||
'promptCapabilities': {
|
||||
'supportsImage': True,
|
||||
'supportsAudio': False,
|
||||
'supportsResources': True,
|
||||
},
|
||||
}
|
||||
|
||||
async def _session_new(self, params: dict[str, Any] | None) -> dict[str, Any]:
|
||||
# Client may provide preferred model or workspace details; ignore for MVP
|
||||
session_id = await self._generate_session_id()
|
||||
self.sessions[session_id] = SessionState()
|
||||
return {'sessionId': session_id}
|
||||
|
||||
async def _session_prompt(self, params: dict[str, Any] | None) -> dict[str, Any]:
|
||||
assert params is not None
|
||||
session_id = params.get('sessionId', '')
|
||||
# Accept either 'messages' (python test harness) or 'prompt' (ACP TS client)
|
||||
_messages = (
|
||||
params.get('messages') if 'messages' in params else params.get('prompt', [])
|
||||
)
|
||||
# For MVP we just echo a text agent message chunk and end_turn
|
||||
state = self.sessions.get(session_id)
|
||||
if state is None:
|
||||
raise RuntimeError(f'Unknown session {session_id}')
|
||||
|
||||
# cancel any pending prompt
|
||||
if state.pending_task and not state.pending_task.done():
|
||||
state.pending_task.cancel()
|
||||
try:
|
||||
await state.pending_task
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
|
||||
async def run_turn() -> None:
|
||||
try:
|
||||
await self.rpc.send_notification(
|
||||
'session/update',
|
||||
{
|
||||
'sessionId': session_id,
|
||||
'update': {
|
||||
'sessionUpdate': 'agent_message_chunk',
|
||||
'content': {
|
||||
'type': 'text',
|
||||
'text': 'OpenHands is thinking...',
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
await asyncio.sleep(0.2)
|
||||
await self.rpc.send_notification(
|
||||
'session/update',
|
||||
{
|
||||
'sessionId': session_id,
|
||||
'update': {
|
||||
'sessionUpdate': 'agent_message_chunk',
|
||||
'content': {
|
||||
'type': 'text',
|
||||
'text': 'This is a minimal ACP adapter.',
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
# Send nothing more
|
||||
raise
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.exception('Error in prompt run: %s', e)
|
||||
|
||||
task = asyncio.create_task(run_turn())
|
||||
state.pending_task = task
|
||||
try:
|
||||
await task
|
||||
stop_reason = 'end_turn'
|
||||
except asyncio.CancelledError:
|
||||
stop_reason = 'cancelled'
|
||||
return {'stopReason': stop_reason}
|
||||
|
||||
async def _session_cancel(self, params: dict[str, Any] | None) -> None:
|
||||
if not params:
|
||||
return
|
||||
session_id = params.get('sessionId')
|
||||
if not session_id:
|
||||
return
|
||||
state = self.sessions.get(session_id)
|
||||
if state and state.pending_task and not state.pending_task.done():
|
||||
state.pending_task.cancel()
|
||||
|
||||
async def _generate_session_id(self) -> str:
|
||||
# Simple increasing counter based id
|
||||
return f'sess-{len(self.sessions) + 1:04d}'
|
||||
|
||||
|
||||
async def run_stdio_server() -> None:
|
||||
import sys
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
reader = asyncio.StreamReader()
|
||||
reader_protocol = asyncio.StreamReaderProtocol(reader)
|
||||
await loop.connect_read_pipe(lambda: reader_protocol, sys.stdin)
|
||||
write_transport, write_protocol = await loop.connect_write_pipe(
|
||||
asyncio.streams.FlowControlMixin, sys.stdout
|
||||
)
|
||||
writer = asyncio.StreamWriter(write_transport, write_protocol, reader, loop)
|
||||
|
||||
stream = NDJsonStdio(reader, writer)
|
||||
rpc = JsonRpcConnection(stream)
|
||||
server = ACPAgentServer(rpc)
|
||||
await rpc.serve(server.handle_request, server.handle_notification)
|
||||
122
tests/unit/test_acp_minimal.py
Normal file
122
tests/unit/test_acp_minimal.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.acp.jsonrpc import JsonRpcConnection, NDJsonStdio
|
||||
from openhands.acp.server import ACPAgentServer
|
||||
|
||||
|
||||
class MemoryRW:
|
||||
def __init__(self):
|
||||
self.read_q: asyncio.Queue[bytes] = asyncio.Queue()
|
||||
self.write_q: asyncio.Queue[bytes] = asyncio.Queue()
|
||||
|
||||
def get_streams(self):
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
async def reader_gen(reader: asyncio.StreamReader):
|
||||
while True:
|
||||
data = await self.read_q.get()
|
||||
reader.feed_data(data)
|
||||
if data == b'':
|
||||
reader.feed_eof()
|
||||
break
|
||||
|
||||
async def make_reader():
|
||||
reader = asyncio.StreamReader()
|
||||
asyncio.create_task(reader_gen(reader))
|
||||
return reader
|
||||
|
||||
class DummyProto(asyncio.Protocol):
|
||||
async def _drain_helper(self) -> None: # satisfy StreamWriter.drain()
|
||||
return None
|
||||
|
||||
async def make_writer(reader: asyncio.StreamReader):
|
||||
class DummyTransport(asyncio.Transport):
|
||||
def write(inner_self, data: bytes) -> None:
|
||||
self.write_q.put_nowait(data)
|
||||
|
||||
def is_closing(inner_self) -> bool: # noqa: PLW3201
|
||||
return False
|
||||
|
||||
return asyncio.StreamWriter(DummyTransport(), DummyProto(), reader, loop)
|
||||
|
||||
return make_reader, make_writer
|
||||
|
||||
|
||||
async def rpc_pair():
|
||||
mem = MemoryRW()
|
||||
make_reader, make_writer = mem.get_streams()
|
||||
reader = await make_reader()
|
||||
writer = await make_writer(reader)
|
||||
|
||||
stream = NDJsonStdio(reader, writer)
|
||||
rpc = JsonRpcConnection(stream)
|
||||
server = ACPAgentServer(rpc)
|
||||
|
||||
async def serve():
|
||||
await rpc.serve(server.handle_request, server.handle_notification)
|
||||
|
||||
task = asyncio.create_task(serve())
|
||||
return mem, task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_minimal_initialize_and_prompt():
|
||||
mem, task = await rpc_pair()
|
||||
|
||||
def encode(obj: Any) -> bytes:
|
||||
return (json.dumps(obj) + '\n').encode()
|
||||
|
||||
# send initialize request
|
||||
mem.read_q.put_nowait(
|
||||
encode({'jsonrpc': '2.0', 'id': 1, 'method': 'initialize', 'params': {}})
|
||||
)
|
||||
|
||||
# read initialize response
|
||||
data = await mem.write_q.get()
|
||||
msg = json.loads(data.decode())
|
||||
assert msg['id'] == 1
|
||||
assert 'result' in msg
|
||||
assert msg['result']['protocolVersion'] == 1
|
||||
|
||||
# new session
|
||||
mem.read_q.put_nowait(
|
||||
encode({'jsonrpc': '2.0', 'id': 2, 'method': 'session/new', 'params': {}})
|
||||
)
|
||||
msg = json.loads((await mem.write_q.get()).decode())
|
||||
assert msg['id'] == 2
|
||||
session_id = msg['result']['sessionId']
|
||||
|
||||
# prompt
|
||||
mem.read_q.put_nowait(
|
||||
encode(
|
||||
{
|
||||
'jsonrpc': '2.0',
|
||||
'id': 3,
|
||||
'method': 'session/prompt',
|
||||
'params': {'sessionId': session_id, 'messages': []},
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# Expect one or more session/update notifications before the result
|
||||
while True:
|
||||
msg = json.loads((await mem.write_q.get()).decode())
|
||||
if 'method' in msg:
|
||||
assert msg['method'] == 'session/update'
|
||||
assert msg['params']['sessionId'] == session_id
|
||||
continue
|
||||
# Then response to prompt
|
||||
assert msg['id'] == 3
|
||||
assert msg['result']['stopReason'] in ('end_turn', 'cancelled')
|
||||
break
|
||||
|
||||
# Close
|
||||
mem.read_q.put_nowait(b'')
|
||||
await asyncio.sleep(0) # let server finish
|
||||
task.cancel()
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await task
|
||||
Reference in New Issue
Block a user