mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
33 Commits
multi-swe-
...
feature/pr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
75a1fad77e | ||
|
|
030d934621 | ||
|
|
b1ac189aaa | ||
|
|
0fa92ccfe4 | ||
|
|
8ddad5a52c | ||
|
|
8f1182135f | ||
|
|
ce8d857690 | ||
|
|
84b2b5a062 | ||
|
|
4076445a7a | ||
|
|
e58d6a9e35 | ||
|
|
7d5e64507c | ||
|
|
f3ef5e84dc | ||
|
|
41d4cb5d29 | ||
|
|
c06772fbc6 | ||
|
|
4f8baf3698 | ||
|
|
aa5e9f792c | ||
|
|
a0c4d5217b | ||
|
|
5aeeaca0f0 | ||
|
|
ba014c957e | ||
|
|
6c67517f56 | ||
|
|
2825bb6dc3 | ||
|
|
ea3076364f | ||
|
|
f6245b9a99 | ||
|
|
2e6fa13550 | ||
|
|
315d586b14 | ||
|
|
3774a459df | ||
|
|
4fde183c0b | ||
|
|
95e60953f1 | ||
|
|
aab80f2975 | ||
|
|
83783c44b3 | ||
|
|
207d628817 | ||
|
|
f51ecec3e7 | ||
|
|
b89f4c1748 |
@@ -5,24 +5,23 @@
|
||||
* Mock Service Worker.
|
||||
* @see https://github.com/mswjs/msw
|
||||
* - Please do NOT modify this file.
|
||||
* - Please do NOT serve this file on production.
|
||||
*/
|
||||
|
||||
const PACKAGE_VERSION = '2.8.4'
|
||||
const INTEGRITY_CHECKSUM = '00729d72e3b82faf54ca8b9621dbb96f'
|
||||
const PACKAGE_VERSION = '2.10.2'
|
||||
const INTEGRITY_CHECKSUM = 'f5825c521429caf22a4dd13b66e243af'
|
||||
const IS_MOCKED_RESPONSE = Symbol('isMockedResponse')
|
||||
const activeClientIds = new Set()
|
||||
|
||||
self.addEventListener('install', function () {
|
||||
addEventListener('install', function () {
|
||||
self.skipWaiting()
|
||||
})
|
||||
|
||||
self.addEventListener('activate', function (event) {
|
||||
addEventListener('activate', function (event) {
|
||||
event.waitUntil(self.clients.claim())
|
||||
})
|
||||
|
||||
self.addEventListener('message', async function (event) {
|
||||
const clientId = event.source.id
|
||||
addEventListener('message', async function (event) {
|
||||
const clientId = Reflect.get(event.source || {}, 'id')
|
||||
|
||||
if (!clientId || !self.clients) {
|
||||
return
|
||||
@@ -94,17 +93,18 @@ self.addEventListener('message', async function (event) {
|
||||
}
|
||||
})
|
||||
|
||||
self.addEventListener('fetch', function (event) {
|
||||
const { request } = event
|
||||
|
||||
addEventListener('fetch', function (event) {
|
||||
// Bypass navigation requests.
|
||||
if (request.mode === 'navigate') {
|
||||
if (event.request.mode === 'navigate') {
|
||||
return
|
||||
}
|
||||
|
||||
// Opening the DevTools triggers the "only-if-cached" request
|
||||
// that cannot be handled by the worker. Bypass such requests.
|
||||
if (request.cache === 'only-if-cached' && request.mode !== 'same-origin') {
|
||||
if (
|
||||
event.request.cache === 'only-if-cached' &&
|
||||
event.request.mode !== 'same-origin'
|
||||
) {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -115,48 +115,62 @@ self.addEventListener('fetch', function (event) {
|
||||
return
|
||||
}
|
||||
|
||||
// Generate unique request ID.
|
||||
const requestId = crypto.randomUUID()
|
||||
event.respondWith(handleRequest(event, requestId))
|
||||
})
|
||||
|
||||
/**
|
||||
* @param {FetchEvent} event
|
||||
* @param {string} requestId
|
||||
*/
|
||||
async function handleRequest(event, requestId) {
|
||||
const client = await resolveMainClient(event)
|
||||
const requestCloneForEvents = event.request.clone()
|
||||
const response = await getResponse(event, client, requestId)
|
||||
|
||||
// Send back the response clone for the "response:*" life-cycle events.
|
||||
// Ensure MSW is active and ready to handle the message, otherwise
|
||||
// this message will pend indefinitely.
|
||||
if (client && activeClientIds.has(client.id)) {
|
||||
;(async function () {
|
||||
const responseClone = response.clone()
|
||||
const serializedRequest = await serializeRequest(requestCloneForEvents)
|
||||
|
||||
sendToClient(
|
||||
client,
|
||||
{
|
||||
type: 'RESPONSE',
|
||||
payload: {
|
||||
requestId,
|
||||
isMockedResponse: IS_MOCKED_RESPONSE in response,
|
||||
// Clone the response so both the client and the library could consume it.
|
||||
const responseClone = response.clone()
|
||||
|
||||
sendToClient(
|
||||
client,
|
||||
{
|
||||
type: 'RESPONSE',
|
||||
payload: {
|
||||
isMockedResponse: IS_MOCKED_RESPONSE in response,
|
||||
request: {
|
||||
id: requestId,
|
||||
...serializedRequest,
|
||||
},
|
||||
response: {
|
||||
type: responseClone.type,
|
||||
status: responseClone.status,
|
||||
statusText: responseClone.statusText,
|
||||
body: responseClone.body,
|
||||
headers: Object.fromEntries(responseClone.headers.entries()),
|
||||
body: responseClone.body,
|
||||
},
|
||||
},
|
||||
[responseClone.body],
|
||||
)
|
||||
})()
|
||||
},
|
||||
responseClone.body ? [serializedRequest.body, responseClone.body] : [],
|
||||
)
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
// Resolve the main client for the given event.
|
||||
// Client that issues a request doesn't necessarily equal the client
|
||||
// that registered the worker. It's with the latter the worker should
|
||||
// communicate with during the response resolving phase.
|
||||
/**
|
||||
* Resolve the main client for the given event.
|
||||
* Client that issues a request doesn't necessarily equal the client
|
||||
* that registered the worker. It's with the latter the worker should
|
||||
* communicate with during the response resolving phase.
|
||||
* @param {FetchEvent} event
|
||||
* @returns {Promise<Client | undefined>}
|
||||
*/
|
||||
async function resolveMainClient(event) {
|
||||
const client = await self.clients.get(event.clientId)
|
||||
|
||||
@@ -184,12 +198,16 @@ async function resolveMainClient(event) {
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {FetchEvent} event
|
||||
* @param {Client | undefined} client
|
||||
* @param {string} requestId
|
||||
* @returns {Promise<Response>}
|
||||
*/
|
||||
async function getResponse(event, client, requestId) {
|
||||
const { request } = event
|
||||
|
||||
// Clone the request because it might've been already used
|
||||
// (i.e. its body has been read and sent to the client).
|
||||
const requestClone = request.clone()
|
||||
const requestClone = event.request.clone()
|
||||
|
||||
function passthrough() {
|
||||
// Cast the request headers to a new Headers instance
|
||||
@@ -230,29 +248,17 @@ async function getResponse(event, client, requestId) {
|
||||
}
|
||||
|
||||
// Notify the client that a request has been intercepted.
|
||||
const requestBuffer = await request.arrayBuffer()
|
||||
const serializedRequest = await serializeRequest(event.request)
|
||||
const clientMessage = await sendToClient(
|
||||
client,
|
||||
{
|
||||
type: 'REQUEST',
|
||||
payload: {
|
||||
id: requestId,
|
||||
url: request.url,
|
||||
mode: request.mode,
|
||||
method: request.method,
|
||||
headers: Object.fromEntries(request.headers.entries()),
|
||||
cache: request.cache,
|
||||
credentials: request.credentials,
|
||||
destination: request.destination,
|
||||
integrity: request.integrity,
|
||||
redirect: request.redirect,
|
||||
referrer: request.referrer,
|
||||
referrerPolicy: request.referrerPolicy,
|
||||
body: requestBuffer,
|
||||
keepalive: request.keepalive,
|
||||
...serializedRequest,
|
||||
},
|
||||
},
|
||||
[requestBuffer],
|
||||
[serializedRequest.body],
|
||||
)
|
||||
|
||||
switch (clientMessage.type) {
|
||||
@@ -268,6 +274,12 @@ async function getResponse(event, client, requestId) {
|
||||
return passthrough()
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {Client} client
|
||||
* @param {any} message
|
||||
* @param {Array<Transferable>} transferrables
|
||||
* @returns {Promise<any>}
|
||||
*/
|
||||
function sendToClient(client, message, transferrables = []) {
|
||||
return new Promise((resolve, reject) => {
|
||||
const channel = new MessageChannel()
|
||||
@@ -280,14 +292,18 @@ function sendToClient(client, message, transferrables = []) {
|
||||
resolve(event.data)
|
||||
}
|
||||
|
||||
client.postMessage(
|
||||
message,
|
||||
[channel.port2].concat(transferrables.filter(Boolean)),
|
||||
)
|
||||
client.postMessage(message, [
|
||||
channel.port2,
|
||||
...transferrables.filter(Boolean),
|
||||
])
|
||||
})
|
||||
}
|
||||
|
||||
async function respondWithMock(response) {
|
||||
/**
|
||||
* @param {Response} response
|
||||
* @returns {Response}
|
||||
*/
|
||||
function respondWithMock(response) {
|
||||
// Setting response status code to 0 is a no-op.
|
||||
// However, when responding with a "Response.error()", the produced Response
|
||||
// instance will have status code set to 0. Since it's not possible to create
|
||||
@@ -305,3 +321,24 @@ async function respondWithMock(response) {
|
||||
|
||||
return mockedResponse
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {Request} request
|
||||
*/
|
||||
async function serializeRequest(request) {
|
||||
return {
|
||||
url: request.url,
|
||||
mode: request.mode,
|
||||
method: request.method,
|
||||
headers: Object.fromEntries(request.headers.entries()),
|
||||
cache: request.cache,
|
||||
credentials: request.credentials,
|
||||
destination: request.destination,
|
||||
integrity: request.integrity,
|
||||
redirect: request.redirect,
|
||||
referrer: request.referrer,
|
||||
referrerPolicy: request.referrerPolicy,
|
||||
body: await request.arrayBuffer(),
|
||||
keepalive: request.keepalive,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ from openhands.agenthub import ( # noqa: E402
|
||||
codeact_agent,
|
||||
dummy_agent,
|
||||
loc_agent,
|
||||
proxy_agent,
|
||||
readonly_agent,
|
||||
visualbrowsing_agent,
|
||||
)
|
||||
@@ -19,6 +20,7 @@ __all__ = [
|
||||
'dummy_agent',
|
||||
'browsing_agent',
|
||||
'visualbrowsing_agent',
|
||||
'proxy_agent',
|
||||
'readonly_agent',
|
||||
'loc_agent',
|
||||
]
|
||||
|
||||
54
openhands/agenthub/proxy_agent/README.md
Normal file
54
openhands/agenthub/proxy_agent/README.md
Normal file
@@ -0,0 +1,54 @@
|
||||
# Proxy Agent
|
||||
|
||||
This folder is an implementation of a Proxy Agent.
|
||||
The Proxy Agent delegates a given task to an appropriate agent capable of accomplishing it.
|
||||
The list of available agents is defined in agent_list.json, located in this directory.
|
||||
|
||||
A key feature of the Proxy Agent is that, in addition to delegating task to different agents available locally within OpenHands, it can also send messages to agents hosted on different server, using A2A Protocol.
|
||||
|
||||
## How to run
|
||||
### Set as the initial agent
|
||||
This agent is designed to be the initial agent that receives user input at the start of a session.
|
||||
Configure the Proxy Agent as the initial agent of a session.
|
||||
```mermaid
|
||||
flowchart LR
|
||||
u((User)) --> A
|
||||
|
||||
subgraph Server1
|
||||
A["Proxy Agent"]
|
||||
B["Other Agents<br>(e.g. CodeActAgent)"]
|
||||
A -->|delegate| B
|
||||
end
|
||||
|
||||
subgraph Server2
|
||||
D["Other Agents"]
|
||||
end
|
||||
|
||||
A --->|Remote Delegation| D
|
||||
|
||||
```
|
||||
|
||||
### Place agent_list.json
|
||||
Place agent_list.json under openhands/agenthub/proxy_agent. Below is an example of its structure:
|
||||
```json
|
||||
{
|
||||
"local": {
|
||||
"CodeActAgent": {
|
||||
"agent_name": "CodeActAgent",
|
||||
"description": "A helpful AI assistant that can interact with a computer to solve tasks."
|
||||
}
|
||||
},
|
||||
"remote": {
|
||||
"FooAgent": {
|
||||
"agent_name": "FooAgent",
|
||||
"url": "http(s)://IP or FQDN:port",
|
||||
"description": "A brief description of FooAgent.",
|
||||
"protocol": "A2A"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
The contents of this JSON file are simply passed as a string to the agent as part of its prompt, assisting the LLM in selecting the most suitable agent.
|
||||
Therefore, there areno strict formatting requirements, but please keep the following points in mind:
|
||||
- Clearly specify whether the agent is available locally within the same instance or hosted on a different instance.
|
||||
- If an agent is hosted on a different instance, explicitly provide the URL where that instance is hosted.
|
||||
4
openhands/agenthub/proxy_agent/__init__.py
Normal file
4
openhands/agenthub/proxy_agent/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from openhands.agenthub.proxy_agent.proxy_agent import ProxyAgent
|
||||
from openhands.controller.agent import Agent
|
||||
|
||||
Agent.register('ProxyAgent', ProxyAgent)
|
||||
180
openhands/agenthub/proxy_agent/function_calling.py
Normal file
180
openhands/agenthub/proxy_agent/function_calling.py
Normal file
@@ -0,0 +1,180 @@
|
||||
import json
|
||||
|
||||
from litellm import (
|
||||
ChatCompletionToolParam,
|
||||
ChatCompletionToolParamFunctionChunk,
|
||||
ModelResponse,
|
||||
)
|
||||
|
||||
from openhands.core.exceptions import (
|
||||
FunctionCallNotExistsError,
|
||||
FunctionCallValidationError,
|
||||
)
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action import (
|
||||
Action,
|
||||
AgentDelegateAction,
|
||||
AgentFinishAction,
|
||||
IPythonRunCellAction,
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.events.tool import ToolCallMetadata
|
||||
|
||||
_DELEGATE_LOCAL = """Delegate a task to a local agent hosted on a same instance.
|
||||
"""
|
||||
|
||||
DelegateLocalTool = ChatCompletionToolParam(
|
||||
type='function',
|
||||
function=ChatCompletionToolParamFunctionChunk(
|
||||
name='delegate_local',
|
||||
description=_DELEGATE_LOCAL,
|
||||
parameters={
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'agent_name': {
|
||||
'type': 'string',
|
||||
'description': 'The name of the agent to delegate to.',
|
||||
},
|
||||
'task': {
|
||||
'type': 'string',
|
||||
'description': 'The task to delegate.',
|
||||
},
|
||||
},
|
||||
'required': ['agent_name', 'task'],
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
_DELEGATE_REMOTE = """Delegate a task to a remote agent hosted on a remote server using A2A Protocol.
|
||||
"""
|
||||
|
||||
DelegateRemoteTool = ChatCompletionToolParam(
|
||||
type='function',
|
||||
function=ChatCompletionToolParamFunctionChunk(
|
||||
name='delegate_remote',
|
||||
description=_DELEGATE_REMOTE,
|
||||
parameters={
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'url': {
|
||||
'type': 'string',
|
||||
'description': 'The URL of the remote agent.',
|
||||
},
|
||||
'task': {
|
||||
'type': 'string',
|
||||
'description': 'The task to delegate.',
|
||||
},
|
||||
'session_id': {
|
||||
'type': 'string',
|
||||
'description': 'The session id of the remote agent.',
|
||||
},
|
||||
'task_id': {
|
||||
'type': 'string',
|
||||
'description': 'The task id of the remote agent.',
|
||||
}
|
||||
},
|
||||
'required': ['url', 'task'],
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
_FINISH_DESCRIPTION = """Finish the interaction when the task is complete OR if the assistant cannot proceed further with the task."""
|
||||
|
||||
FinishTool = ChatCompletionToolParam(
|
||||
type='function',
|
||||
function=ChatCompletionToolParamFunctionChunk(
|
||||
name='finish',
|
||||
description=_FINISH_DESCRIPTION,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def combine_thought(action: Action, thought: str) -> Action:
|
||||
if not hasattr(action, 'thought'):
|
||||
return action
|
||||
if thought:
|
||||
action.thought = thought
|
||||
return action
|
||||
|
||||
|
||||
def response_to_action(response: ModelResponse) -> Action:
|
||||
action: Action = None # type: ignore
|
||||
assert len(response.choices) == 1, 'Only one choice is supported for now'
|
||||
assistant_msg = response.choices[0].message
|
||||
if assistant_msg.tool_calls:
|
||||
# Check if there's assistant_msg.content. If so, add it to the thought
|
||||
thought = ''
|
||||
if isinstance(assistant_msg.content, str):
|
||||
thought = assistant_msg.content
|
||||
elif isinstance(assistant_msg.content, list):
|
||||
for msg in assistant_msg.content:
|
||||
if msg['type'] == 'text':
|
||||
thought += msg['text']
|
||||
|
||||
# Assume only one tool call is returned
|
||||
if len(assistant_msg.tool_calls) != 1:
|
||||
logger.info(
|
||||
f'Expected only one tool call, but got {len(assistant_msg.tool_calls)}'
|
||||
)
|
||||
tool_call = assistant_msg.tool_calls[0]
|
||||
try:
|
||||
arguments = json.loads(tool_call.function.arguments)
|
||||
except json.decoder.JSONDecodeError as e:
|
||||
raise RuntimeError(
|
||||
f'Failed to parse tool call arguments: {tool_call.function.arguments}'
|
||||
) from e
|
||||
|
||||
if tool_call.function.name == 'delegate_remote':
|
||||
for k in ['url', 'task']:
|
||||
if k not in arguments:
|
||||
raise FunctionCallValidationError(
|
||||
f'Missing required argument "{k}" in tool call {tool_call.function.name}'
|
||||
)
|
||||
|
||||
message = arguments['task']
|
||||
message = message.replace('\n', '\\\n')
|
||||
url = arguments['url']
|
||||
session_id = arguments.get('session_id')
|
||||
task_id = arguments.get('task_id')
|
||||
if session_id and task_id:
|
||||
code = (
|
||||
f'await send_task_A2A('
|
||||
f'message="{message}", '
|
||||
f'url="{url}", '
|
||||
f'session_id="{session_id}", '
|
||||
f'task_id="{task_id}")'
|
||||
)
|
||||
else:
|
||||
code = (
|
||||
f'await send_task_A2A('
|
||||
f'message="{message}", '
|
||||
f'url="{url}")'
|
||||
)
|
||||
|
||||
action = IPythonRunCellAction(code=code, include_extra=False)
|
||||
|
||||
elif tool_call.function.name == 'finish':
|
||||
action = AgentFinishAction()
|
||||
else:
|
||||
raise FunctionCallNotExistsError(
|
||||
f'Tool {tool_call.function.name} is not registered. (arguments: {arguments}). Please check the tool name and retry with an existing tool.'
|
||||
)
|
||||
|
||||
action = combine_thought(action, thought)
|
||||
# Add metadata for tool calling
|
||||
action.tool_call_metadata = ToolCallMetadata(
|
||||
tool_call_id=tool_call.id,
|
||||
function_name=tool_call.function.name,
|
||||
model_response=response,
|
||||
total_calls_in_response=len(assistant_msg.tool_calls),
|
||||
)
|
||||
|
||||
else:
|
||||
action = MessageAction(content=assistant_msg.content, wait_for_response=True)
|
||||
|
||||
return action
|
||||
|
||||
|
||||
def get_tools() -> list[ChatCompletionToolParam]:
|
||||
tools = [DelegateLocalTool, DelegateRemoteTool, FinishTool]
|
||||
return tools
|
||||
6
openhands/agenthub/proxy_agent/prompts/system_prompt.j2
Normal file
6
openhands/agenthub/proxy_agent/prompts/system_prompt.j2
Normal file
@@ -0,0 +1,6 @@
|
||||
You are a Proxy Agent, a helpful AI assistant which is responsible for delegating tasks to other agents.
|
||||
You delegate tasks to agents that exist locally or are hosted remotely on another server.
|
||||
<IMPORTANT>
|
||||
* Never execute an action again once the action has been completed.
|
||||
* When you delegate a task to a remote-host agent, you must read the response of the remote agent and return a message to the user as if you were that agent.
|
||||
</IMPORTANT>
|
||||
126
openhands/agenthub/proxy_agent/proxy_agent.py
Normal file
126
openhands/agenthub/proxy_agent/proxy_agent.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
import openhands.agenthub.proxy_agent.function_calling as proxy_function_calling
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import AgentConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.message import Message, TextContent
|
||||
from openhands.events.action import Action, MessageAction
|
||||
from openhands.events.event import Event
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.memory.conversation_memory import ConversationMemory
|
||||
from openhands.microagent.prompt_manager import PromptManager
|
||||
from openhands.runtime.plugins import (
|
||||
AgentSkillsRequirement,
|
||||
JupyterRequirement,
|
||||
PluginRequirement,
|
||||
)
|
||||
|
||||
|
||||
class ProxyAgent(Agent):
|
||||
sandbox_plugins: list[PluginRequirement] = [
|
||||
AgentSkillsRequirement(),
|
||||
JupyterRequirement(),
|
||||
]
|
||||
|
||||
def __init__(self, llm: LLM, config: AgentConfig) -> None:
|
||||
super().__init__(llm, config)
|
||||
self.reset()
|
||||
|
||||
self.mock_function_calling = False
|
||||
if not self.llm.is_function_calling_active():
|
||||
logger.info(
|
||||
f'Function calling not enabled for model {self.llm.config.model}. '
|
||||
'Mocking function calling via prompting.'
|
||||
)
|
||||
self.mock_function_calling = True
|
||||
|
||||
# Function calling mode
|
||||
self.tools = proxy_function_calling.get_tools()
|
||||
|
||||
self._prompt_manager = PromptManager(
|
||||
prompt_dir=os.path.join(os.path.dirname(__file__), 'prompts'),
|
||||
)
|
||||
|
||||
# Create a ConversationMemory instance
|
||||
# _prompt_manager is guaranteed to be set at this point
|
||||
assert self._prompt_manager is not None
|
||||
self.conversation_memory = ConversationMemory(self.config, self._prompt_manager)
|
||||
|
||||
agent_list_path = os.path.join(os.path.dirname(__file__), 'agent_list.json')
|
||||
if not os.path.exists(agent_list_path):
|
||||
raise FileNotFoundError('agent list file not found')
|
||||
with open(agent_list_path, 'r') as f:
|
||||
self.agent_list = json.load(f)
|
||||
if self.agent_list == {}:
|
||||
raise ValueError('agent list file is empty')
|
||||
|
||||
def step(self, state: State) -> Action:
|
||||
# Prepare the message to send to the LLM
|
||||
initial_user_message = self._get_initial_user_message(state.history)
|
||||
messages = self._get_messages(state.history, initial_user_message)
|
||||
|
||||
params: dict = {
|
||||
'messages': self.llm.format_messages_for_llm(messages),
|
||||
}
|
||||
params['tools'] = self.tools
|
||||
if self.mock_function_calling:
|
||||
params['mock_function_calling'] = True
|
||||
response = self.llm.completion(**params)
|
||||
|
||||
# Assume only one tool call is returned
|
||||
action = proxy_function_calling.response_to_action(response)
|
||||
return action
|
||||
|
||||
def _get_initial_user_message(self, history: list[Event]) -> MessageAction:
|
||||
"""Finds the initial user message action from the full history."""
|
||||
initial_user_message: MessageAction | None = None
|
||||
for event in history:
|
||||
if isinstance(event, MessageAction) and event.source == 'user':
|
||||
initial_user_message = event
|
||||
break
|
||||
|
||||
if initial_user_message is None:
|
||||
# This should not happen in a valid conversation
|
||||
raise ValueError(
|
||||
'Initial user message not found in history. Please report this issue.'
|
||||
)
|
||||
return initial_user_message
|
||||
|
||||
def _get_messages(
|
||||
self, events: list[Event], initial_user_message: MessageAction
|
||||
) -> list[Message]:
|
||||
if not self.prompt_manager:
|
||||
raise Exception('Prompt Manager not instantiated.')
|
||||
|
||||
# Use ConversationMemory to process events (including SystemMessageAction)
|
||||
messages = self.conversation_memory.process_events(
|
||||
condensed_history=events,
|
||||
initial_user_action=initial_user_message,
|
||||
max_message_chars=self.llm.config.max_message_chars,
|
||||
vision_is_active=self.llm.vision_is_active(),
|
||||
)
|
||||
|
||||
agent_list_message = Message(
|
||||
role='system',
|
||||
content=[
|
||||
TextContent(
|
||||
text='Available agents are the following:'
|
||||
+ json.dumps(self.agent_list)
|
||||
)
|
||||
],
|
||||
)
|
||||
if len(messages) > 1:
|
||||
messages.insert(1, agent_list_message)
|
||||
else:
|
||||
messages.append(agent_list_message)
|
||||
|
||||
if self.llm.is_caching_prompt_active():
|
||||
self.conversation_memory.apply_prompt_caching(messages)
|
||||
|
||||
return messages
|
||||
|
||||
def reset(self) -> None:
|
||||
super().reset()
|
||||
11
openhands/runtime/plugins/agent_skills/a2a_client/README.md
Normal file
11
openhands/runtime/plugins/agent_skills/a2a_client/README.md
Normal file
@@ -0,0 +1,11 @@
|
||||
# A2A Client
|
||||
This is an implementation of an A2A Client, called by agents within runtime container.
|
||||
|
||||
This directory contains code from [A2A](https://github.com/google/A2A), originally licensed under the Apache License 2.0.
|
||||
The original source has been modified to fit the needs of this project.
|
||||
See third_party_license/LICENSE for the full license text.
|
||||
|
||||
## Modifications
|
||||
|
||||
- Removed unused components (e.g. PushNotfication) from original code.
|
||||
- Implemented 'send_task_a2a' with customed I/O to make it more convenient for AI Agent
|
||||
@@ -0,0 +1,9 @@
|
||||
from openhands.runtime.plugins.agent_skills.a2a_client import a2a_client
|
||||
from openhands.runtime.plugins.agent_skills.utils.dependency import import_functions
|
||||
|
||||
import_functions(
|
||||
module=a2a_client,
|
||||
function_names=a2a_client.__all__,
|
||||
target_globals=globals(),
|
||||
)
|
||||
__all__ = a2a_client.__all__
|
||||
@@ -0,0 +1,80 @@
|
||||
from uuid import uuid4
|
||||
|
||||
from openhands.runtime.plugins.agent_skills.a2a_client.common.client import (
|
||||
A2ACardResolver,
|
||||
A2AClient,
|
||||
)
|
||||
from openhands.runtime.plugins.agent_skills.a2a_client.common.types import (
|
||||
TaskState,
|
||||
)
|
||||
|
||||
|
||||
async def send_task_A2A(url, message, session_id=0, task_id=0):
|
||||
"""
|
||||
Send a task to an agent hosted on remote server, compatible with A2A protocol.
|
||||
"""
|
||||
## Get the agent card
|
||||
card_resolver = A2ACardResolver(url)
|
||||
card = card_resolver.get_agent_card()
|
||||
|
||||
print('======= Agent Card ========')
|
||||
print(card.model_dump_json(exclude_none=True))
|
||||
|
||||
client = A2AClient(agent_card=card)
|
||||
|
||||
if session_id == 0:
|
||||
session_id = uuid4().hex
|
||||
if task_id == 0:
|
||||
task_id = uuid4().hex
|
||||
|
||||
streaming = card.capabilities.streaming
|
||||
print('======= Session ID and Task ID ========')
|
||||
print(f'Session ID: {session_id}')
|
||||
print(f'Task ID: {task_id}')
|
||||
print('If you want to send more input, use the same session ID and task ID.')
|
||||
|
||||
print('========= starting a task ======== ')
|
||||
await completeTask(client, message, streaming, task_id, session_id)
|
||||
|
||||
|
||||
async def completeTask(client: A2AClient, message, streaming, task_id, session_id):
|
||||
prompt = message
|
||||
|
||||
message = {
|
||||
'role': 'user',
|
||||
'parts': [
|
||||
{
|
||||
'type': 'text',
|
||||
'text': prompt,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
payload = {
|
||||
'id': task_id,
|
||||
'sessionId': session_id,
|
||||
'acceptedOutputModes': ['text'],
|
||||
'message': message,
|
||||
}
|
||||
|
||||
taskResult = None
|
||||
if streaming:
|
||||
response_stream = client.send_task_streaming(payload)
|
||||
async for result in response_stream:
|
||||
print(f'stream event => {result.model_dump_json(exclude_none=True)}')
|
||||
taskResult = await client.get_task({'id': task_id})
|
||||
else:
|
||||
taskResult = await client.send_task(payload)
|
||||
print(f'\n{taskResult.model_dump_json(exclude_none=True)}')
|
||||
|
||||
## if the result is that more input is required, tell the user and exit.
|
||||
if taskResult.result:
|
||||
state = TaskState(taskResult.result.status.state)
|
||||
if state.name == TaskState.INPUT_REQUIRED.name:
|
||||
print('Task requires more input. Use this tool again to provide it.')
|
||||
else:
|
||||
## task is complete
|
||||
return True
|
||||
|
||||
|
||||
__all__ = ['send_task_A2A']
|
||||
@@ -0,0 +1,4 @@
|
||||
from .client import A2AClient
|
||||
from .card_resolver import A2ACardResolver
|
||||
|
||||
__all__ = ["A2AClient", "A2ACardResolver"]
|
||||
@@ -0,0 +1,21 @@
|
||||
import httpx
|
||||
from openhands.runtime.plugins.agent_skills.a2a_client.common.types import (
|
||||
AgentCard,
|
||||
A2AClientJSONError,
|
||||
)
|
||||
import json
|
||||
|
||||
|
||||
class A2ACardResolver:
|
||||
def __init__(self, base_url, agent_card_path="/.well-known/agent.json"):
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.agent_card_path = agent_card_path.lstrip("/")
|
||||
|
||||
def get_agent_card(self) -> AgentCard:
|
||||
with httpx.Client() as client:
|
||||
response = client.get(self.base_url + "/" + self.agent_card_path)
|
||||
response.raise_for_status()
|
||||
try:
|
||||
return AgentCard(**response.json())
|
||||
except json.JSONDecodeError as e:
|
||||
raise A2AClientJSONError(str(e)) from e
|
||||
@@ -0,0 +1,88 @@
|
||||
import json
|
||||
from typing import Any, AsyncIterable
|
||||
|
||||
import httpx
|
||||
from httpx_sse import connect_sse
|
||||
|
||||
from openhands.runtime.plugins.agent_skills.a2a_client.common.types import (
|
||||
A2AClientHTTPError,
|
||||
A2AClientJSONError,
|
||||
AgentCard,
|
||||
CancelTaskRequest,
|
||||
CancelTaskResponse,
|
||||
GetTaskPushNotificationRequest,
|
||||
GetTaskPushNotificationResponse,
|
||||
GetTaskRequest,
|
||||
GetTaskResponse,
|
||||
JSONRPCRequest,
|
||||
SendTaskRequest,
|
||||
SendTaskResponse,
|
||||
SendTaskStreamingRequest,
|
||||
SendTaskStreamingResponse,
|
||||
SetTaskPushNotificationRequest,
|
||||
SetTaskPushNotificationResponse,
|
||||
)
|
||||
|
||||
|
||||
class A2AClient:
|
||||
def __init__(self, agent_card: AgentCard | None = None, url: str | None = None):
|
||||
if agent_card:
|
||||
self.url = agent_card.url
|
||||
elif url:
|
||||
self.url = url
|
||||
else:
|
||||
raise ValueError('Must provide either agent_card or url')
|
||||
|
||||
async def send_task(self, payload: dict[str, Any]) -> SendTaskResponse:
|
||||
request = SendTaskRequest(params=payload)
|
||||
return SendTaskResponse(**await self._send_request(request))
|
||||
|
||||
async def send_task_streaming(
|
||||
self, payload: dict[str, Any]
|
||||
) -> AsyncIterable[SendTaskStreamingResponse]:
|
||||
request = SendTaskStreamingRequest(params=payload)
|
||||
with httpx.Client(timeout=None) as client:
|
||||
with connect_sse(
|
||||
client, 'POST', self.url, json=request.model_dump()
|
||||
) as event_source:
|
||||
try:
|
||||
for sse in event_source.iter_sse():
|
||||
yield SendTaskStreamingResponse(**json.loads(sse.data))
|
||||
except json.JSONDecodeError as e:
|
||||
raise A2AClientJSONError(str(e)) from e
|
||||
except httpx.RequestError as e:
|
||||
raise A2AClientHTTPError(400, str(e)) from e
|
||||
|
||||
async def _send_request(self, request: JSONRPCRequest) -> dict[str, Any]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
# Image generation could take time, adding timeout
|
||||
response = await client.post(
|
||||
self.url, json=request.model_dump(), timeout=30
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise A2AClientHTTPError(e.response.status_code, str(e)) from e
|
||||
except json.JSONDecodeError as e:
|
||||
raise A2AClientJSONError(str(e)) from e
|
||||
|
||||
async def get_task(self, payload: dict[str, Any]) -> GetTaskResponse:
|
||||
request = GetTaskRequest(params=payload)
|
||||
return GetTaskResponse(**await self._send_request(request))
|
||||
|
||||
async def cancel_task(self, payload: dict[str, Any]) -> CancelTaskResponse:
|
||||
request = CancelTaskRequest(params=payload)
|
||||
return CancelTaskResponse(**await self._send_request(request))
|
||||
|
||||
async def set_task_callback(
|
||||
self, payload: dict[str, Any]
|
||||
) -> SetTaskPushNotificationResponse:
|
||||
request = SetTaskPushNotificationRequest(params=payload)
|
||||
return SetTaskPushNotificationResponse(**await self._send_request(request))
|
||||
|
||||
async def get_task_callback(
|
||||
self, payload: dict[str, Any]
|
||||
) -> GetTaskPushNotificationResponse:
|
||||
request = GetTaskPushNotificationRequest(params=payload)
|
||||
return GetTaskPushNotificationResponse(**await self._send_request(request))
|
||||
@@ -0,0 +1,365 @@
|
||||
from typing import Union, Any
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from typing import Literal, List, Annotated, Optional
|
||||
from datetime import datetime
|
||||
from pydantic import model_validator, ConfigDict, field_serializer
|
||||
from uuid import uuid4
|
||||
from enum import Enum
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
class TaskState(str, Enum):
|
||||
SUBMITTED = "submitted"
|
||||
WORKING = "working"
|
||||
INPUT_REQUIRED = "input-required"
|
||||
COMPLETED = "completed"
|
||||
CANCELED = "canceled"
|
||||
FAILED = "failed"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
class TextPart(BaseModel):
|
||||
type: Literal["text"] = "text"
|
||||
text: str
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class FileContent(BaseModel):
|
||||
name: str | None = None
|
||||
mimeType: str | None = None
|
||||
bytes: str | None = None
|
||||
uri: str | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_content(self) -> Self:
|
||||
if not (self.bytes or self.uri):
|
||||
raise ValueError("Either 'bytes' or 'uri' must be present in the file data")
|
||||
if self.bytes and self.uri:
|
||||
raise ValueError(
|
||||
"Only one of 'bytes' or 'uri' can be present in the file data"
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class FilePart(BaseModel):
|
||||
type: Literal["file"] = "file"
|
||||
file: FileContent
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class DataPart(BaseModel):
|
||||
type: Literal["data"] = "data"
|
||||
data: dict[str, Any]
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
Part = Annotated[Union[TextPart, FilePart, DataPart], Field(discriminator="type")]
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
role: Literal["user", "agent"]
|
||||
parts: List[Part]
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class TaskStatus(BaseModel):
|
||||
state: TaskState
|
||||
message: Message | None = None
|
||||
timestamp: datetime = Field(default_factory=datetime.now)
|
||||
|
||||
@field_serializer("timestamp")
|
||||
def serialize_dt(self, dt: datetime, _info):
|
||||
return dt.isoformat()
|
||||
|
||||
|
||||
class Artifact(BaseModel):
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
parts: List[Part]
|
||||
metadata: dict[str, Any] | None = None
|
||||
index: int = 0
|
||||
append: bool | None = None
|
||||
lastChunk: bool | None = None
|
||||
|
||||
|
||||
class Task(BaseModel):
|
||||
id: str
|
||||
sessionId: str | None = None
|
||||
status: TaskStatus
|
||||
artifacts: List[Artifact] | None = None
|
||||
history: List[Message] | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class TaskStatusUpdateEvent(BaseModel):
|
||||
id: str
|
||||
status: TaskStatus
|
||||
final: bool = False
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class TaskArtifactUpdateEvent(BaseModel):
|
||||
id: str
|
||||
artifact: Artifact
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class AuthenticationInfo(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
schemes: List[str]
|
||||
credentials: str | None = None
|
||||
|
||||
|
||||
class PushNotificationConfig(BaseModel):
|
||||
url: str
|
||||
token: str | None = None
|
||||
authentication: AuthenticationInfo | None = None
|
||||
|
||||
|
||||
class TaskIdParams(BaseModel):
|
||||
id: str
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class TaskQueryParams(TaskIdParams):
|
||||
historyLength: int | None = None
|
||||
|
||||
|
||||
class TaskSendParams(BaseModel):
|
||||
id: str
|
||||
sessionId: str = Field(default_factory=lambda: uuid4().hex)
|
||||
message: Message
|
||||
acceptedOutputModes: Optional[List[str]] = None
|
||||
pushNotification: PushNotificationConfig | None = None
|
||||
historyLength: int | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class TaskPushNotificationConfig(BaseModel):
|
||||
id: str
|
||||
pushNotificationConfig: PushNotificationConfig
|
||||
|
||||
|
||||
## RPC Messages
|
||||
|
||||
|
||||
class JSONRPCMessage(BaseModel):
|
||||
jsonrpc: Literal["2.0"] = "2.0"
|
||||
id: int | str | None = Field(default_factory=lambda: uuid4().hex)
|
||||
|
||||
|
||||
class JSONRPCRequest(JSONRPCMessage):
|
||||
method: str
|
||||
params: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class JSONRPCError(BaseModel):
|
||||
code: int
|
||||
message: str
|
||||
data: Any | None = None
|
||||
|
||||
|
||||
class JSONRPCResponse(JSONRPCMessage):
|
||||
result: Any | None = None
|
||||
error: JSONRPCError | None = None
|
||||
|
||||
|
||||
class SendTaskRequest(JSONRPCRequest):
|
||||
method: Literal["tasks/send"] = "tasks/send"
|
||||
params: TaskSendParams
|
||||
|
||||
|
||||
class SendTaskResponse(JSONRPCResponse):
|
||||
result: Task | None = None
|
||||
|
||||
|
||||
class SendTaskStreamingRequest(JSONRPCRequest):
|
||||
method: Literal["tasks/sendSubscribe"] = "tasks/sendSubscribe"
|
||||
params: TaskSendParams
|
||||
|
||||
|
||||
class SendTaskStreamingResponse(JSONRPCResponse):
|
||||
result: TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None = None
|
||||
|
||||
|
||||
class GetTaskRequest(JSONRPCRequest):
|
||||
method: Literal["tasks/get"] = "tasks/get"
|
||||
params: TaskQueryParams
|
||||
|
||||
|
||||
class GetTaskResponse(JSONRPCResponse):
|
||||
result: Task | None = None
|
||||
|
||||
|
||||
class CancelTaskRequest(JSONRPCRequest):
|
||||
method: Literal["tasks/cancel",] = "tasks/cancel"
|
||||
params: TaskIdParams
|
||||
|
||||
|
||||
class CancelTaskResponse(JSONRPCResponse):
|
||||
result: Task | None = None
|
||||
|
||||
|
||||
class SetTaskPushNotificationRequest(JSONRPCRequest):
|
||||
method: Literal["tasks/pushNotification/set",] = "tasks/pushNotification/set"
|
||||
params: TaskPushNotificationConfig
|
||||
|
||||
|
||||
class SetTaskPushNotificationResponse(JSONRPCResponse):
|
||||
result: TaskPushNotificationConfig | None = None
|
||||
|
||||
|
||||
class GetTaskPushNotificationRequest(JSONRPCRequest):
|
||||
method: Literal["tasks/pushNotification/get",] = "tasks/pushNotification/get"
|
||||
params: TaskIdParams
|
||||
|
||||
|
||||
class GetTaskPushNotificationResponse(JSONRPCResponse):
|
||||
result: TaskPushNotificationConfig | None = None
|
||||
|
||||
|
||||
class TaskResubscriptionRequest(JSONRPCRequest):
|
||||
method: Literal["tasks/resubscribe",] = "tasks/resubscribe"
|
||||
params: TaskIdParams
|
||||
|
||||
|
||||
A2ARequest = TypeAdapter(
|
||||
Annotated[
|
||||
Union[
|
||||
SendTaskRequest,
|
||||
GetTaskRequest,
|
||||
CancelTaskRequest,
|
||||
SetTaskPushNotificationRequest,
|
||||
GetTaskPushNotificationRequest,
|
||||
TaskResubscriptionRequest,
|
||||
SendTaskStreamingRequest,
|
||||
],
|
||||
Field(discriminator="method"),
|
||||
]
|
||||
)
|
||||
|
||||
## Error types
|
||||
|
||||
|
||||
class JSONParseError(JSONRPCError):
|
||||
code: int = -32700
|
||||
message: str = "Invalid JSON payload"
|
||||
data: Any | None = None
|
||||
|
||||
|
||||
class InvalidRequestError(JSONRPCError):
|
||||
code: int = -32600
|
||||
message: str = "Request payload validation error"
|
||||
data: Any | None = None
|
||||
|
||||
|
||||
class MethodNotFoundError(JSONRPCError):
|
||||
code: int = -32601
|
||||
message: str = "Method not found"
|
||||
data: None = None
|
||||
|
||||
|
||||
class InvalidParamsError(JSONRPCError):
|
||||
code: int = -32602
|
||||
message: str = "Invalid parameters"
|
||||
data: Any | None = None
|
||||
|
||||
|
||||
class InternalError(JSONRPCError):
|
||||
code: int = -32603
|
||||
message: str = "Internal error"
|
||||
data: Any | None = None
|
||||
|
||||
|
||||
class TaskNotFoundError(JSONRPCError):
|
||||
code: int = -32001
|
||||
message: str = "Task not found"
|
||||
data: None = None
|
||||
|
||||
|
||||
class TaskNotCancelableError(JSONRPCError):
|
||||
code: int = -32002
|
||||
message: str = "Task cannot be canceled"
|
||||
data: None = None
|
||||
|
||||
|
||||
class PushNotificationNotSupportedError(JSONRPCError):
|
||||
code: int = -32003
|
||||
message: str = "Push Notification is not supported"
|
||||
data: None = None
|
||||
|
||||
|
||||
class UnsupportedOperationError(JSONRPCError):
|
||||
code: int = -32004
|
||||
message: str = "This operation is not supported"
|
||||
data: None = None
|
||||
|
||||
|
||||
class ContentTypeNotSupportedError(JSONRPCError):
|
||||
code: int = -32005
|
||||
message: str = "Incompatible content types"
|
||||
data: None = None
|
||||
|
||||
|
||||
class AgentProvider(BaseModel):
|
||||
organization: str
|
||||
url: str | None = None
|
||||
|
||||
|
||||
class AgentCapabilities(BaseModel):
|
||||
streaming: bool = False
|
||||
pushNotifications: bool = False
|
||||
stateTransitionHistory: bool = False
|
||||
|
||||
|
||||
class AgentAuthentication(BaseModel):
|
||||
schemes: List[str]
|
||||
credentials: str | None = None
|
||||
|
||||
|
||||
class AgentSkill(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str | None = None
|
||||
tags: List[str] | None = None
|
||||
examples: List[str] | None = None
|
||||
inputModes: List[str] | None = None
|
||||
outputModes: List[str] | None = None
|
||||
|
||||
|
||||
class AgentCard(BaseModel):
|
||||
name: str
|
||||
description: str | None = None
|
||||
url: str
|
||||
provider: AgentProvider | None = None
|
||||
version: str
|
||||
documentationUrl: str | None = None
|
||||
capabilities: AgentCapabilities
|
||||
authentication: AgentAuthentication | None = None
|
||||
defaultInputModes: List[str] = ["text"]
|
||||
defaultOutputModes: List[str] = ["text"]
|
||||
skills: List[AgentSkill]
|
||||
|
||||
|
||||
class A2AClientError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class A2AClientHTTPError(A2AClientError):
|
||||
def __init__(self, status_code: int, message: str):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
super().__init__(f"HTTP Error {status_code}: {message}")
|
||||
|
||||
|
||||
class A2AClientJSONError(A2AClientError):
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
super().__init__(f"JSON Error: {message}")
|
||||
|
||||
|
||||
class MissingAPIKeyError(Exception):
|
||||
"""Exception for missing API key."""
|
||||
|
||||
pass
|
||||
@@ -0,0 +1,202 @@
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
@@ -1,6 +1,10 @@
|
||||
from inspect import signature
|
||||
|
||||
from openhands.runtime.plugins.agent_skills import file_ops, file_reader
|
||||
from openhands.runtime.plugins.agent_skills import (
|
||||
a2a_client,
|
||||
file_ops,
|
||||
file_reader,
|
||||
)
|
||||
from openhands.runtime.plugins.agent_skills.utils.dependency import import_functions
|
||||
|
||||
import_functions(
|
||||
@@ -9,8 +13,13 @@ import_functions(
|
||||
import_functions(
|
||||
module=file_reader, function_names=file_reader.__all__, target_globals=globals()
|
||||
)
|
||||
import_functions(
|
||||
module=a2a_client,
|
||||
function_names=a2a_client.__all__,
|
||||
target_globals=globals(),
|
||||
)
|
||||
|
||||
__all__ = file_ops.__all__ + file_reader.__all__
|
||||
__all__ = file_ops.__all__ + file_reader.__all__ + a2a_client.__all__
|
||||
|
||||
try:
|
||||
from openhands.runtime.plugins.agent_skills import repo_ops
|
||||
|
||||
@@ -117,7 +117,7 @@ RUN /openhands/micromamba/bin/micromamba run -n openhands poetry install --only
|
||||
|
||||
# Install playwright and its dependencies
|
||||
RUN apt-get update && \
|
||||
/openhands/micromamba/bin/micromamba run -n openhands poetry run pip install playwright && \
|
||||
/openhands/micromamba/bin/micromamba run -n openhands poetry run pip install playwright httpx httpx-sse pydantic && \
|
||||
/openhands/micromamba/bin/micromamba run -n openhands poetry run playwright install --with-deps chromium
|
||||
|
||||
# Set environment variables and permissions
|
||||
|
||||
23
poetry.lock
generated
23
poetry.lock
generated
@@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "aiofiles"
|
||||
@@ -400,7 +400,7 @@ description = "LTS Port of Python audioop"
|
||||
optional = false
|
||||
python-versions = ">=3.13"
|
||||
groups = ["main"]
|
||||
markers = "python_version >= \"3.13\""
|
||||
markers = "python_version == \"3.13\""
|
||||
files = [
|
||||
{file = "audioop_lts-0.2.1-cp313-abi3-macosx_10_13_universal2.whl", hash = "sha256:fd1345ae99e17e6910f47ce7d52673c6a1a70820d78b67de1b7abb3af29c426a"},
|
||||
{file = "audioop_lts-0.2.1-cp313-abi3-macosx_10_13_x86_64.whl", hash = "sha256:e175350da05d2087e12cea8e72a70a1a8b14a17e92ed2022952a4419689ede5e"},
|
||||
@@ -1580,7 +1580,7 @@ files = [
|
||||
{file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
|
||||
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
|
||||
]
|
||||
markers = {main = "platform_system == \"Windows\" or sys_platform == \"win32\" or os_name == \"nt\"", dev = "os_name == \"nt\" or sys_platform == \"win32\"", runtime = "sys_platform == \"win32\"", test = "platform_system == \"Windows\" or sys_platform == \"win32\""}
|
||||
markers = {main = "platform_system == \"Windows\" or os_name == \"nt\" or sys_platform == \"win32\"", dev = "os_name == \"nt\" or sys_platform == \"win32\"", runtime = "sys_platform == \"win32\"", test = "platform_system == \"Windows\" or sys_platform == \"win32\""}
|
||||
|
||||
[[package]]
|
||||
name = "comm"
|
||||
@@ -2974,8 +2974,8 @@ files = [
|
||||
google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]}
|
||||
google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0dev"
|
||||
proto-plus = [
|
||||
{version = ">=1.22.3,<2.0.0dev"},
|
||||
{version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""},
|
||||
{version = ">=1.22.3,<2.0.0dev"},
|
||||
]
|
||||
protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0dev"
|
||||
|
||||
@@ -2997,8 +2997,8 @@ googleapis-common-protos = ">=1.56.2,<2.0.0"
|
||||
grpcio = {version = ">=1.49.1,<2.0.0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}
|
||||
grpcio-status = {version = ">=1.49.1,<2.0.0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}
|
||||
proto-plus = [
|
||||
{version = ">=1.22.3,<2.0.0"},
|
||||
{version = ">=1.25.0,<2.0.0", markers = "python_version >= \"3.13\""},
|
||||
{version = ">=1.22.3,<2.0.0"},
|
||||
]
|
||||
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<7.0.0"
|
||||
requests = ">=2.18.0,<3.0.0"
|
||||
@@ -3216,8 +3216,8 @@ google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0", extras
|
||||
google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0"
|
||||
grpc-google-iam-v1 = ">=0.14.0,<1.0.0"
|
||||
proto-plus = [
|
||||
{version = ">=1.22.3,<2.0.0"},
|
||||
{version = ">=1.25.0,<2.0.0", markers = "python_version >= \"3.13\""},
|
||||
{version = ">=1.22.3,<2.0.0"},
|
||||
]
|
||||
protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<7.0.0"
|
||||
|
||||
@@ -5422,7 +5422,7 @@ version = "0.61.0"
|
||||
description = "A module for monitoring memory usage of a python program"
|
||||
optional = false
|
||||
python-versions = ">=3.5"
|
||||
groups = ["runtime"]
|
||||
groups = ["main", "runtime"]
|
||||
files = [
|
||||
{file = "memory_profiler-0.61.0-py3-none-any.whl", hash = "sha256:400348e61031e3942ad4d4109d18753b2fb08c2f6fb8290671c5513a34182d84"},
|
||||
{file = "memory_profiler-0.61.0.tar.gz", hash = "sha256:4e5b73d7864a1d1292fb76a03e82a3e78ef934d06828a698d9dada76da2067b0"},
|
||||
@@ -6479,8 +6479,8 @@ files = [
|
||||
[package.dependencies]
|
||||
googleapis-common-protos = ">=1.52,<2.0"
|
||||
grpcio = [
|
||||
{version = ">=1.63.2,<2.0.0", markers = "python_version < \"3.13\""},
|
||||
{version = ">=1.66.2,<2.0.0", markers = "python_version >= \"3.13\""},
|
||||
{version = ">=1.63.2,<2.0.0", markers = "python_version < \"3.13\""},
|
||||
]
|
||||
opentelemetry-api = ">=1.15,<2.0"
|
||||
opentelemetry-exporter-otlp-proto-common = "1.34.1"
|
||||
@@ -9243,7 +9243,6 @@ files = [
|
||||
{file = "setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922"},
|
||||
{file = "setuptools-80.9.0.tar.gz", hash = "sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c"},
|
||||
]
|
||||
markers = {evaluation = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
|
||||
|
||||
[package.extras]
|
||||
check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\"", "ruff (>=0.8.0) ; sys_platform != \"cygwin\""]
|
||||
@@ -9486,7 +9485,7 @@ description = "Standard library aifc redistribution. \"dead battery\"."
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
groups = ["main"]
|
||||
markers = "python_version >= \"3.13\""
|
||||
markers = "python_version == \"3.13\""
|
||||
files = [
|
||||
{file = "standard_aifc-3.13.0-py3-none-any.whl", hash = "sha256:f7ae09cc57de1224a0dd8e3eb8f73830be7c3d0bc485de4c1f82b4a7f645ac66"},
|
||||
{file = "standard_aifc-3.13.0.tar.gz", hash = "sha256:64e249c7cb4b3daf2fdba4e95721f811bde8bdfc43ad9f936589b7bb2fae2e43"},
|
||||
@@ -9503,7 +9502,7 @@ description = "Standard library chunk redistribution. \"dead battery\"."
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
groups = ["main"]
|
||||
markers = "python_version >= \"3.13\""
|
||||
markers = "python_version == \"3.13\""
|
||||
files = [
|
||||
{file = "standard_chunk-3.13.0-py3-none-any.whl", hash = "sha256:17880a26c285189c644bd5bd8f8ed2bdb795d216e3293e6dbe55bbd848e2982c"},
|
||||
{file = "standard_chunk-3.13.0.tar.gz", hash = "sha256:4ac345d37d7e686d2755e01836b8d98eda0d1a3ee90375e597ae43aaf064d654"},
|
||||
@@ -11665,4 +11664,4 @@ cffi = ["cffi (>=1.11)"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = "^3.12,<3.14"
|
||||
content-hash = "47df4fc76b97147ff31169028edafaf35c1f4e661c7ab74bad48cb0ceea06aba"
|
||||
content-hash = "0b8da1a7da2d598f9ca4a8933245c99495f7a34bb26e1221eebd7ba2fa1d6ddc"
|
||||
|
||||
@@ -71,6 +71,11 @@ python-frontmatter = "^1.1.0"
|
||||
# TODO: Should these go into the runtime group?
|
||||
ipywidgets = "^8.1.5"
|
||||
qtconsole = "^5.6.1"
|
||||
memory-profiler = "^0.61.0"
|
||||
playwright = "^1.51.0"
|
||||
pydantic = "^2.11.3"
|
||||
httpx = "^0.28.1"
|
||||
httpx-sse = "^0.4.0"
|
||||
PyPDF2 = "*"
|
||||
python-pptx = "*"
|
||||
pylatexenc = "*"
|
||||
|
||||
471
tests/unit/test_a2a_client.py
Normal file
471
tests/unit/test_a2a_client.py
Normal file
@@ -0,0 +1,471 @@
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from httpx import HTTPStatusError, Request, Response
|
||||
from httpx_sse import ServerSentEvent
|
||||
|
||||
from openhands.runtime.plugins.agent_skills.a2a_client.a2a_client import (
|
||||
completeTask,
|
||||
send_task_A2A,
|
||||
)
|
||||
from openhands.runtime.plugins.agent_skills.a2a_client.common.client import (
|
||||
A2ACardResolver,
|
||||
A2AClient,
|
||||
)
|
||||
from openhands.runtime.plugins.agent_skills.a2a_client.common.types import (
|
||||
A2AClientHTTPError,
|
||||
A2AClientJSONError,
|
||||
AgentCard,
|
||||
CancelTaskRequest,
|
||||
CancelTaskResponse,
|
||||
GetTaskPushNotificationRequest,
|
||||
GetTaskPushNotificationResponse,
|
||||
JSONRPCRequest,
|
||||
SendTaskRequest,
|
||||
SendTaskResponse,
|
||||
SetTaskPushNotificationRequest,
|
||||
SetTaskPushNotificationResponse,
|
||||
TaskIdParams,
|
||||
TaskPushNotificationConfig,
|
||||
TaskSendParams,
|
||||
TaskState,
|
||||
TaskStatus,
|
||||
TaskStatusUpdateEvent,
|
||||
)
|
||||
|
||||
# Tests for openhands/runtime/plugins/agent_skills/a2a_client/a2a_client.py
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('openhands.runtime.plugins.agent_skills.a2a_client.a2a_client.A2ACardResolver')
|
||||
@patch('openhands.runtime.plugins.agent_skills.a2a_client.a2a_client.A2AClient')
|
||||
async def test_send_task_A2A(mock_a2a_client, mock_card_resolver):
|
||||
# Mock: card resolver, agent card, A2A Client and completeTask
|
||||
mock_card = Mock()
|
||||
mock_card.capabilities.streaming = False
|
||||
mock_card.model_dump_json.return_value = '{}'
|
||||
mock_card_resolver.return_value.get_agent_card.return_value = mock_card
|
||||
|
||||
mock_client = Mock()
|
||||
mock_a2a_client.return_value = mock_client
|
||||
|
||||
with patch(
|
||||
'openhands.runtime.plugins.agent_skills.a2a_client.a2a_client.completeTask',
|
||||
new=AsyncMock(),
|
||||
) as mock_complete_task:
|
||||
await send_task_A2A('http://example.com', 'test message')
|
||||
|
||||
mock_card_resolver.assert_called_once_with('http://example.com')
|
||||
mock_card_resolver.return_value.get_agent_card.assert_called_once()
|
||||
mock_a2a_client.assert_called_once_with(agent_card=mock_card)
|
||||
|
||||
mock_complete_task.assert_called_once()
|
||||
_, _, streaming, task_id, session_id = mock_complete_task.call_args[0]
|
||||
assert streaming is False
|
||||
assert len(task_id) > 0
|
||||
assert len(session_id) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('openhands.runtime.plugins.agent_skills.a2a_client.a2a_client.A2AClient')
|
||||
async def test_completeTask_non_streaming(mock_a2a_client):
|
||||
# Mock: A2AClient
|
||||
mock_client = Mock()
|
||||
mock_client.send_task = AsyncMock(
|
||||
return_value=Mock(result=Mock(status=Mock(state=TaskState.COMPLETED.value)))
|
||||
)
|
||||
mock_a2a_client.return_value = mock_client
|
||||
|
||||
result = await completeTask(
|
||||
mock_client, 'test message', False, 'task_id', 'session_id'
|
||||
)
|
||||
|
||||
mock_client.send_task.assert_called_once()
|
||||
payload = mock_client.send_task.call_args[0][0]
|
||||
assert payload['id'] == 'task_id'
|
||||
assert payload['sessionId'] == 'session_id'
|
||||
assert payload['message']['role'] == 'user'
|
||||
assert payload['message']['parts'][0]['text'] == 'test message'
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('openhands.runtime.plugins.agent_skills.a2a_client.a2a_client.A2AClient')
|
||||
async def test_completeTask_streaming(mock_a2a_client):
|
||||
# Mock A2AClient
|
||||
mock_client = Mock()
|
||||
mock_client.send_task_streaming = Mock(return_value=AsyncMock())
|
||||
mock_client.get_task = AsyncMock(
|
||||
return_value=Mock(result=Mock(status=Mock(state=TaskState.COMPLETED.value)))
|
||||
)
|
||||
mock_a2a_client.return_value = mock_client
|
||||
|
||||
result = await completeTask(
|
||||
mock_client, 'test message', True, 'task_id', 'session_id'
|
||||
)
|
||||
|
||||
mock_client.send_task_streaming.assert_called_once()
|
||||
mock_client.get_task.assert_called_once_with({'id': 'task_id'})
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('openhands.runtime.plugins.agent_skills.a2a_client.a2a_client.A2AClient')
|
||||
async def test_completeTask_input_required(mock_a2a_client):
|
||||
# Mock A2AClient
|
||||
mock_client = Mock()
|
||||
mock_client.send_task = AsyncMock(
|
||||
return_value=Mock(
|
||||
result=Mock(status=Mock(state=TaskState.INPUT_REQUIRED.value))
|
||||
)
|
||||
)
|
||||
mock_a2a_client.return_value = mock_client
|
||||
|
||||
result = await completeTask(
|
||||
mock_client, 'test message', False, 'task_id', 'session_id'
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
# Tests for openhands/runtime/plugins/agent_skills/a2a_client/common/client/client.py
|
||||
TEST_URL = 'https://example.com'
|
||||
|
||||
|
||||
def make_mock_request() -> JSONRPCRequest:
|
||||
return JSONRPCRequest(
|
||||
jsonrpc='2.0',
|
||||
id='1',
|
||||
method='test_method',
|
||||
params={},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test__send_request_success():
|
||||
mock_request = make_mock_request()
|
||||
mock_response_data = {
|
||||
'jsonrpc': '2.0',
|
||||
'id': '1',
|
||||
'method': 'test_method',
|
||||
'result': {},
|
||||
}
|
||||
|
||||
client = A2AClient(url=TEST_URL)
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.raise_for_status = Mock()
|
||||
mock_response.json.return_value = mock_response_data
|
||||
|
||||
mock_post = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch('httpx.AsyncClient.post', mock_post):
|
||||
result = await client._send_request(mock_request)
|
||||
|
||||
assert result == mock_response_data
|
||||
mock_post.assert_awaited_once_with(
|
||||
TEST_URL, json=mock_request.model_dump(), timeout=30
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test__send_request_http_error():
|
||||
mock_request = make_mock_request()
|
||||
client = A2AClient(url=TEST_URL)
|
||||
|
||||
bad_request_response = Response(status_code=400, request=Request('POST', TEST_URL))
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.raise_for_status.side_effect = HTTPStatusError(
|
||||
message='Bad Request',
|
||||
request=bad_request_response.request,
|
||||
response=bad_request_response,
|
||||
)
|
||||
mock_response.json.return_value = {}
|
||||
|
||||
mock_post = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch('httpx.AsyncClient.post', mock_post):
|
||||
with pytest.raises(A2AClientHTTPError) as exc_info:
|
||||
await client._send_request(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert 'Bad Request' in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test__send_request_json_error():
|
||||
mock_request = make_mock_request()
|
||||
client = A2AClient(url=TEST_URL)
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.raise_for_status = Mock()
|
||||
mock_response.json.side_effect = json.JSONDecodeError(
|
||||
msg='Expecting value', doc='', pos=0
|
||||
)
|
||||
|
||||
mock_post = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch('httpx.AsyncClient.post', mock_post):
|
||||
with pytest.raises(A2AClientJSONError) as exc_info:
|
||||
await client._send_request(mock_request)
|
||||
|
||||
assert 'Expecting value' in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_task():
|
||||
mock_payload = {
|
||||
'id': 'test-task-id',
|
||||
'sessionId': 'test-session-id',
|
||||
'message': {
|
||||
'role': 'user',
|
||||
'parts': [
|
||||
{
|
||||
'type': 'text',
|
||||
'text': 'Hello, world!',
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
mock_response_data = {
|
||||
'jsonrpc': '2.0',
|
||||
'id': '1',
|
||||
'result': {
|
||||
'id': 'test-task-id',
|
||||
'sessionId': 'test-session-id',
|
||||
'status': {
|
||||
'state': 'submitted',
|
||||
'timestamp': '2025-01-01T00:00:00Z',
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
client = A2AClient(url=TEST_URL)
|
||||
|
||||
with patch.object(
|
||||
client, '_send_request', AsyncMock(return_value=mock_response_data)
|
||||
) as mock_send:
|
||||
result = await client.send_task(mock_payload)
|
||||
|
||||
assert isinstance(result, SendTaskResponse)
|
||||
assert result == SendTaskResponse(**mock_response_data)
|
||||
mock_send.assert_awaited_once()
|
||||
sent_arg = mock_send.call_args.args[0]
|
||||
assert isinstance(sent_arg, SendTaskRequest)
|
||||
assert sent_arg.params == TaskSendParams(**mock_payload)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_task():
|
||||
mock_payload = {'id': 'test-task-id'}
|
||||
mock_response_data = {
|
||||
'jsonrpc': '2.0',
|
||||
'id': '1',
|
||||
'result': {
|
||||
'id': 'test-task-id',
|
||||
'status': {
|
||||
'state': 'canceled',
|
||||
'timestamp': '2025-01-01T00:00:00Z',
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
client = A2AClient(url=TEST_URL)
|
||||
|
||||
with patch.object(
|
||||
client, '_send_request', AsyncMock(return_value=mock_response_data)
|
||||
) as mock_send:
|
||||
result = await client.cancel_task(mock_payload)
|
||||
|
||||
assert isinstance(result, CancelTaskResponse)
|
||||
assert result == CancelTaskResponse(**mock_response_data)
|
||||
mock_send.assert_awaited_once()
|
||||
sent_arg = mock_send.call_args.args[0]
|
||||
assert isinstance(sent_arg, CancelTaskRequest)
|
||||
assert sent_arg.params == TaskIdParams(**mock_payload)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_task_callback():
|
||||
mock_payload = {
|
||||
'id': 'test-task-id',
|
||||
'pushNotificationConfig': {'url': 'https://callback.example.com'},
|
||||
}
|
||||
mock_response_data = {
|
||||
'jsonrpc': '2.0',
|
||||
'id': '1',
|
||||
'result': {
|
||||
'id': 'test-task-id',
|
||||
'pushNotificationConfig': {
|
||||
'url': 'https://callback.example.com',
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
client = A2AClient(url=TEST_URL)
|
||||
|
||||
with patch.object(
|
||||
client, '_send_request', AsyncMock(return_value=mock_response_data)
|
||||
) as mock_send:
|
||||
result = await client.set_task_callback(mock_payload)
|
||||
|
||||
assert isinstance(result, SetTaskPushNotificationResponse)
|
||||
assert result == SetTaskPushNotificationResponse(**mock_response_data)
|
||||
mock_send.assert_awaited_once()
|
||||
sent_arg = mock_send.call_args.args[0]
|
||||
assert isinstance(sent_arg, SetTaskPushNotificationRequest)
|
||||
assert sent_arg.params == TaskPushNotificationConfig(**mock_payload)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_task_callback():
|
||||
mock_payload = {'id': 'test-task-id'}
|
||||
mock_response_data = {
|
||||
'jsonrpc': '2.0',
|
||||
'id': '1',
|
||||
'result': {
|
||||
'id': 'test-task-id',
|
||||
'pushNotificationConfig': {
|
||||
'url': 'https://callback.example.com',
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
client = A2AClient(url=TEST_URL)
|
||||
|
||||
with patch.object(
|
||||
client, '_send_request', AsyncMock(return_value=mock_response_data)
|
||||
) as mock_send:
|
||||
result = await client.get_task_callback(mock_payload)
|
||||
|
||||
assert isinstance(result, GetTaskPushNotificationResponse)
|
||||
assert result == GetTaskPushNotificationResponse(**mock_response_data)
|
||||
mock_send.assert_awaited_once()
|
||||
sent_arg = mock_send.call_args.args[0]
|
||||
assert isinstance(sent_arg, GetTaskPushNotificationRequest)
|
||||
assert sent_arg.params == TaskIdParams(**mock_payload)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_task_streaming():
|
||||
mock_event_source = MagicMock()
|
||||
mock_event_source.iter_sse.return_value = iter(
|
||||
[
|
||||
ServerSentEvent(
|
||||
data=json.dumps(
|
||||
{
|
||||
'jsonrpc': '2.0',
|
||||
'id': '1',
|
||||
'result': {
|
||||
'id': 'test-task-id',
|
||||
'status': {
|
||||
'state': 'submitted',
|
||||
'timestamp': '2025-01-01T00:00:00Z',
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
),
|
||||
ServerSentEvent(
|
||||
data=json.dumps(
|
||||
{
|
||||
'jsonrpc': '2.0',
|
||||
'id': '1',
|
||||
'result': {
|
||||
'id': 'test-task-id',
|
||||
'status': {
|
||||
'state': 'working',
|
||||
'timestamp': '2025-01-01T00:01:00Z',
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
with patch(
|
||||
'openhands.runtime.plugins.agent_skills.a2a_client.common.client.client.connect_sse'
|
||||
) as mock_connect_sse:
|
||||
mock_connect_sse.return_value.__enter__.return_value = mock_event_source
|
||||
|
||||
client = A2AClient(url='http://mock.url')
|
||||
payload = {
|
||||
'id': 'test-task-id',
|
||||
'message': {'role': 'user', 'parts': [{'type': 'text', 'text': 'Hi'}]},
|
||||
}
|
||||
|
||||
responses = []
|
||||
async for res in client.send_task_streaming(payload):
|
||||
responses.append(res)
|
||||
|
||||
assert responses[0].result == TaskStatusUpdateEvent(
|
||||
id='test-task-id',
|
||||
status=TaskStatus(
|
||||
state='submitted',
|
||||
timestamp='2025-01-01T00:00:00Z',
|
||||
),
|
||||
final=False,
|
||||
)
|
||||
|
||||
assert responses[1].result == TaskStatusUpdateEvent(
|
||||
id='test-task-id',
|
||||
status=TaskStatus(
|
||||
state='working',
|
||||
timestamp='2025-01-01T00:01:00Z',
|
||||
),
|
||||
final=False,
|
||||
)
|
||||
|
||||
|
||||
# Tests for openhands/runtime/plugins/agent_skills/a2a_client/common/card_resolver.py
|
||||
|
||||
|
||||
def test_get_agent_card_success():
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
'name': 'TestAgent',
|
||||
'version': '1.0',
|
||||
'capabilities': {
|
||||
'streaming': True,
|
||||
},
|
||||
'url': 'http://example.com',
|
||||
'skills': [
|
||||
{
|
||||
'id': 'test_skill_id',
|
||||
'name': 'test_skill_name',
|
||||
}
|
||||
],
|
||||
}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
|
||||
with patch(
|
||||
'openhands.runtime.plugins.agent_skills.a2a_client.common.client.client.httpx.Client'
|
||||
) as mock_client_cls:
|
||||
mock_client = mock_client_cls.return_value.__enter__.return_value
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
resolver = A2ACardResolver('http://example.com')
|
||||
result = resolver.get_agent_card()
|
||||
|
||||
assert isinstance(result, AgentCard)
|
||||
assert result.name == 'TestAgent'
|
||||
assert result.version == '1.0'
|
||||
|
||||
|
||||
def test_get_agent_card_json_error():
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.side_effect = json.JSONDecodeError('msg', 'doc', 0)
|
||||
mock_response.raise_for_status.return_value = None
|
||||
|
||||
with patch(
|
||||
'openhands.runtime.plugins.agent_skills.a2a_client.common.client.client.httpx.Client'
|
||||
) as mock_client_cls:
|
||||
mock_client = mock_client_cls.return_value.__enter__.return_value
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
resolver = A2ACardResolver('http://example.com')
|
||||
with pytest.raises(A2AClientJSONError):
|
||||
resolver.get_agent_card()
|
||||
133
tests/unit/test_proxy_agent.py
Normal file
133
tests/unit/test_proxy_agent.py
Normal file
@@ -0,0 +1,133 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.agenthub.proxy_agent.function_calling import (
|
||||
DelegateLocalTool,
|
||||
DelegateRemoteTool,
|
||||
FinishTool,
|
||||
get_tools,
|
||||
response_to_action,
|
||||
)
|
||||
from openhands.agenthub.proxy_agent.proxy_agent import ProxyAgent
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import AgentConfig, LLMConfig
|
||||
from openhands.core.exceptions import FunctionCallNotExistsError
|
||||
from openhands.events.action import (
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.llm.llm import LLM
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent() -> ProxyAgent:
|
||||
config = AgentConfig()
|
||||
agent = ProxyAgent(llm=LLM(LLMConfig()), config=config)
|
||||
agent.llm = Mock()
|
||||
agent.llm.config = Mock()
|
||||
agent.llm.config.max_message_chars = 1000
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_state() -> State:
|
||||
state = Mock(spec=State)
|
||||
state.history = []
|
||||
state.extra_data = {}
|
||||
|
||||
return state
|
||||
|
||||
|
||||
def test_get_tools():
|
||||
tools = get_tools()
|
||||
|
||||
assert len(tools) > 0
|
||||
|
||||
# Check required tools are present
|
||||
tool_names = [tool['function']['name'] for tool in tools]
|
||||
assert 'delegate_local' in tool_names
|
||||
assert 'delegate_remote' in tool_names
|
||||
assert 'finish' in tool_names
|
||||
|
||||
|
||||
def test_delegate_local_tool():
|
||||
assert DelegateLocalTool['type'] == 'function'
|
||||
assert DelegateLocalTool['function']['name'] == 'delegate_local'
|
||||
assert list(DelegateLocalTool['function']['parameters']['properties'].keys()) == [
|
||||
'agent_name',
|
||||
'task',
|
||||
]
|
||||
assert DelegateLocalTool['function']['parameters']['required'] == [
|
||||
'agent_name',
|
||||
'task',
|
||||
]
|
||||
|
||||
|
||||
def test_delegate_remote_tool():
|
||||
assert DelegateRemoteTool['type'] == 'function'
|
||||
assert DelegateRemoteTool['function']['name'] == 'delegate_remote'
|
||||
assert list(DelegateRemoteTool['function']['parameters']['properties'].keys()) == [
|
||||
'url',
|
||||
'task',
|
||||
'session_id',
|
||||
'task_id',
|
||||
]
|
||||
assert DelegateRemoteTool['function']['parameters']['required'] == [
|
||||
'url',
|
||||
'task',
|
||||
]
|
||||
|
||||
|
||||
def test_finish_tool():
|
||||
assert FinishTool['type'] == 'function'
|
||||
assert FinishTool['function']['name'] == 'finish'
|
||||
|
||||
|
||||
def test_response_to_action_invalid_tool():
|
||||
# Test response with invalid tool call
|
||||
mock_response = Mock()
|
||||
mock_response.choices = [Mock()]
|
||||
mock_response.choices[0].message = Mock()
|
||||
mock_response.choices[0].message.content = 'Invalid tool'
|
||||
mock_response.choices[0].message.tool_calls = [Mock()]
|
||||
mock_response.choices[0].message.tool_calls[0].id = 'tool_call_10'
|
||||
mock_response.choices[0].message.tool_calls[0].function = Mock()
|
||||
mock_response.choices[0].message.tool_calls[0].function.name = 'invalid_tool'
|
||||
mock_response.choices[0].message.tool_calls[0].function.arguments = '{}'
|
||||
|
||||
with pytest.raises(FunctionCallNotExistsError):
|
||||
response_to_action(mock_response)
|
||||
|
||||
|
||||
def test_step(mock_state: State):
|
||||
# Mock the LLM response
|
||||
mock_response = Mock()
|
||||
mock_response.id = 'mock_id'
|
||||
mock_response.total_calls_in_response = 1
|
||||
mock_response.choices = [Mock()]
|
||||
mock_response.choices[0].message = Mock()
|
||||
mock_response.choices[0].message.content = 'Task completed'
|
||||
mock_response.choices[0].message.tool_calls = []
|
||||
|
||||
llm = Mock()
|
||||
llm.completion = Mock(return_value=mock_response)
|
||||
llm.is_function_calling_active = Mock(return_value=True) # Enable function calling
|
||||
llm.is_caching_prompt_active = Mock(return_value=False)
|
||||
|
||||
# Create agent with mocked LLM
|
||||
config = AgentConfig()
|
||||
config.enable_prompt_extensions = False
|
||||
agent = ProxyAgent(llm=llm, config=config)
|
||||
|
||||
# Test step with no pending actions
|
||||
mock_state.latest_user_message = None
|
||||
mock_state.latest_user_message_id = None
|
||||
mock_state.latest_user_message_timestamp = None
|
||||
mock_state.latest_user_message_cause = None
|
||||
mock_state.latest_user_message_timeout = None
|
||||
mock_state.latest_user_message_llm_metrics = None
|
||||
mock_state.latest_user_message_tool_call_metadata = None
|
||||
|
||||
action = agent.step(mock_state)
|
||||
assert isinstance(action, MessageAction)
|
||||
assert action.content == 'Task completed'
|
||||
Reference in New Issue
Block a user