mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-07 22:14:03 -05:00
feat(backend): New "update microagent prompt" API (#8357)
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com> Co-authored-by: Engel Nyst <engel.nyst@gmail.com>
This commit is contained in:
@@ -11,6 +11,7 @@ import {
|
||||
GetTrajectoryResponse,
|
||||
GitChangeDiff,
|
||||
GitChange,
|
||||
GetMicroagentPromptResponse,
|
||||
} from "./open-hands.types";
|
||||
import { openHands } from "./open-hands-axios";
|
||||
import { ApiSettings, PostApiSettings, Provider } from "#/types/settings";
|
||||
@@ -393,6 +394,20 @@ class OpenHands {
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
static async getMicroagentPrompt(
|
||||
conversationId: string,
|
||||
eventId: number,
|
||||
): Promise<string> {
|
||||
const { data } = await openHands.get<GetMicroagentPromptResponse>(
|
||||
`/api/conversations/${conversationId}/remember_prompt`,
|
||||
{
|
||||
params: { event_id: eventId },
|
||||
},
|
||||
);
|
||||
|
||||
return data.prompt;
|
||||
}
|
||||
}
|
||||
|
||||
export default OpenHands;
|
||||
|
||||
@@ -102,3 +102,8 @@ export interface GitChangeDiff {
|
||||
modified: string;
|
||||
original: string;
|
||||
}
|
||||
|
||||
export interface GetMicroagentPromptResponse {
|
||||
status: string;
|
||||
prompt: string;
|
||||
}
|
||||
|
||||
12
frontend/src/hooks/query/use-get-microagent-prompt.ts
Normal file
12
frontend/src/hooks/query/use-get-microagent-prompt.ts
Normal file
@@ -0,0 +1,12 @@
|
||||
import { useQuery } from "@tanstack/react-query";
|
||||
import { useConversationId } from "../use-conversation-id";
|
||||
import OpenHands from "#/api/open-hands";
|
||||
|
||||
export const useGetMicroagentPrompt = ({ eventId }: { eventId: number }) => {
|
||||
const { conversationId } = useConversationId();
|
||||
|
||||
return useQuery({
|
||||
queryKey: ["conversation", "remember_prompt", conversationId, eventId],
|
||||
queryFn: () => OpenHands.getMicroagentPrompt(conversationId, eventId),
|
||||
});
|
||||
};
|
||||
21
openhands/microagent/prompts/generate_remember_prompt.j2
Normal file
21
openhands/microagent/prompts/generate_remember_prompt.j2
Normal file
@@ -0,0 +1,21 @@
|
||||
You are tasked with generating a prompt that will be used by another AI to update a special reference file. This file contains important information and learnings that are used to carry out certain tasks. The file can be extended over time to incorporate new knowledge and experiences.
|
||||
|
||||
You have been provided with a subset of new events that may require updates to the special file. These events are:
|
||||
<events>
|
||||
{{ events }}
|
||||
</events>
|
||||
|
||||
Your task is to analyze these events and determine what updates, if any, should be made to the special file. Then, you need to generate a prompt that will instruct another AI to make these updates correctly and efficiently.
|
||||
|
||||
When creating your prompt, follow these guidelines:
|
||||
1. Clearly specify which parts of the file need to be updated or if new sections should be added.
|
||||
2. Provide context for why these updates are necessary based on the new events.
|
||||
3. Be specific about the information that should be added or modified.
|
||||
4. Maintain the existing structure and formatting of the file.
|
||||
5. Ensure that the updates are consistent with the current content and don't contradict existing information.
|
||||
|
||||
Now, based on the new events provided, generate a prompt that will guide the AI in making the appropriate updates to the special file. Your prompt should be clear, specific, and actionable. Include your prompt within <update_prompt> tags.
|
||||
|
||||
<update_prompt>
|
||||
|
||||
</update_prompt>
|
||||
@@ -1,11 +1,26 @@
|
||||
import itertools
|
||||
import re
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, Depends, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from openhands.events.event_filter import EventFilter
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.events.action import (
|
||||
ChangeAgentStateAction,
|
||||
NullAction,
|
||||
)
|
||||
from openhands.events.observation import (
|
||||
NullObservation,
|
||||
AgentStateChangedObservation,
|
||||
)
|
||||
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import (
|
||||
PROVIDER_TOKEN_TYPE,
|
||||
@@ -16,12 +31,15 @@ from openhands.integrations.service_types import (
|
||||
ProviderType,
|
||||
SuggestedTask,
|
||||
)
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.server.data_models.agent_loop_info import AgentLoopInfo
|
||||
from openhands.server.data_models.conversation_info import ConversationInfo
|
||||
from openhands.server.data_models.conversation_info_result_set import (
|
||||
ConversationInfoResultSet,
|
||||
)
|
||||
from openhands.server.services.conversation_service import create_new_conversation
|
||||
from openhands.server.session.conversation import ServerConversation
|
||||
from openhands.server.dependencies import get_dependencies
|
||||
from openhands.server.services.conversation_service import create_new_conversation
|
||||
from openhands.server.shared import (
|
||||
@@ -35,10 +53,11 @@ from openhands.server.user_auth import (
|
||||
get_provider_tokens,
|
||||
get_user_id,
|
||||
get_user_secrets,
|
||||
get_user_settings_store,
|
||||
get_user_settings,
|
||||
)
|
||||
from openhands.server.user_auth.user_auth import AuthType
|
||||
from openhands.server.utils import get_conversation_store
|
||||
from openhands.server.utils import get_conversation_store, get_conversation as get_conversation_object
|
||||
from openhands.storage.conversation.conversation_store import ConversationStore
|
||||
from openhands.storage.data_models.conversation_metadata import (
|
||||
ConversationMetadata,
|
||||
@@ -47,6 +66,7 @@ from openhands.storage.data_models.conversation_metadata import (
|
||||
from openhands.storage.data_models.conversation_status import ConversationStatus
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
from openhands.storage.data_models.user_secrets import UserSecrets
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
from openhands.utils.async_utils import wait_all
|
||||
from openhands.utils.conversation_summary import get_default_conversation_title
|
||||
|
||||
@@ -271,6 +291,76 @@ async def delete_conversation(
|
||||
return True
|
||||
|
||||
|
||||
@app.get('/conversations/{conversation_id}/remember_prompt')
|
||||
async def get_prompt(
|
||||
event_id: int,
|
||||
user_settings: SettingsStore = Depends(get_user_settings_store),
|
||||
conversation: ServerConversation | None = Depends(get_conversation_object)
|
||||
):
|
||||
if conversation is None:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content={'error': 'Conversation not found.'},
|
||||
)
|
||||
|
||||
# get event stream for the conversation
|
||||
event_stream = conversation.event_stream
|
||||
|
||||
# retrieve the relevant events
|
||||
stringified_events = _get_contextual_events(event_stream, event_id)
|
||||
|
||||
# generate a prompt
|
||||
settings = await user_settings.load()
|
||||
if settings is None:
|
||||
# placeholder for error handling
|
||||
raise ValueError('Settings not found')
|
||||
|
||||
llm_config = LLMConfig(
|
||||
model=settings.llm_model,
|
||||
api_key=settings.llm_api_key,
|
||||
base_url=settings.llm_base_url,
|
||||
)
|
||||
|
||||
prompt_template = generate_prompt_template(stringified_events)
|
||||
prompt = generate_prompt(llm_config, prompt_template)
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
'status': 'success',
|
||||
'prompt': prompt,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def generate_prompt_template(events: str) -> str:
|
||||
env = Environment(loader=FileSystemLoader('openhands/microagent/prompts'))
|
||||
template = env.get_template('generate_remember_prompt.j2')
|
||||
return template.render(events=events)
|
||||
|
||||
|
||||
def generate_prompt(llm_config: LLMConfig, prompt_template: str) -> str:
|
||||
llm = LLM(llm_config)
|
||||
messages = [
|
||||
{
|
||||
'role': 'system',
|
||||
'content': prompt_template,
|
||||
},
|
||||
{
|
||||
'role': 'user',
|
||||
'content': 'Please generate a prompt for the AI to update the special file based on the events provided.',
|
||||
},
|
||||
]
|
||||
|
||||
response = llm.completion(messages=messages)
|
||||
raw_prompt = response['choices'][0]['message']['content'].strip()
|
||||
prompt = re.search(r'<update_prompt>(.*?)</update_prompt>', raw_prompt, re.DOTALL)
|
||||
|
||||
if prompt:
|
||||
return prompt.group(1).strip()
|
||||
else:
|
||||
raise ValueError('No valid prompt found in the response.')
|
||||
|
||||
|
||||
async def _get_conversation_info(
|
||||
conversation: ConversationMetadata,
|
||||
num_connections: int,
|
||||
@@ -411,3 +501,38 @@ async def stop_conversation(
|
||||
},
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
|
||||
def _get_contextual_events(event_stream: EventStream, event_id: int) -> str:
|
||||
# find the specified events to learn from
|
||||
# Get X events around the target event
|
||||
context_size = 4
|
||||
|
||||
agent_event_filter = EventFilter(
|
||||
exclude_hidden=True,
|
||||
exclude_types=(NullAction, NullObservation, ChangeAgentStateAction, AgentStateChangedObservation
|
||||
),
|
||||
) # the types of events that can be in an agent's history
|
||||
|
||||
# from event_id - context_size to event_id..
|
||||
context_before = event_stream.search_events(
|
||||
start_id=event_id,
|
||||
filter=agent_event_filter,
|
||||
reverse=True,
|
||||
limit=context_size,
|
||||
)
|
||||
|
||||
# from event_id to event_id + context_size + 1
|
||||
context_after = event_stream.search_events(
|
||||
start_id=event_id + 1,
|
||||
filter=agent_event_filter,
|
||||
limit=context_size + 1,
|
||||
)
|
||||
|
||||
# context_before is in reverse chronological order, so convert to list and reverse it.
|
||||
ordered_context_before = list(context_before)
|
||||
ordered_context_before.reverse()
|
||||
|
||||
all_events = itertools.chain(ordered_context_before, context_after)
|
||||
stringified_events = '\n'.join(str(event) for event in all_events)
|
||||
return stringified_events
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import uuid
|
||||
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -18,6 +20,15 @@ async def get_conversation_store(request: Request) -> ConversationStore | None:
|
||||
return conversation_store
|
||||
|
||||
|
||||
async def generate_unique_conversation_id(
|
||||
conversation_store: ConversationStore,
|
||||
) -> str:
|
||||
conversation_id = uuid.uuid4().hex
|
||||
while await conversation_store.exists(conversation_id):
|
||||
conversation_id = uuid.uuid4().hex
|
||||
return conversation_id
|
||||
|
||||
|
||||
async def get_conversation(
|
||||
conversation_id: str, user_id: str | None = Depends(get_user_id)
|
||||
):
|
||||
|
||||
548
tests/unit/test_contextual_events.py
Normal file
548
tests/unit/test_contextual_events.py
Normal file
@@ -0,0 +1,548 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from openhands.events.action import (
|
||||
Action,
|
||||
ChangeAgentStateAction,
|
||||
CmdRunAction,
|
||||
MessageAction,
|
||||
NullAction,
|
||||
)
|
||||
from openhands.events.event import Event, EventSource
|
||||
from openhands.events.event_filter import (
|
||||
EventFilter, # Needed for ANY matcher type check
|
||||
)
|
||||
from openhands.events.observation import (
|
||||
AgentStateChangedObservation,
|
||||
CmdOutputObservation,
|
||||
NullObservation,
|
||||
)
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.server.routes.manage_conversations import _get_contextual_events
|
||||
|
||||
|
||||
# Helper to create event instances for testing, inspired by test_agent_history.py
|
||||
def create_test_events(event_specs: list[dict]) -> list[Event]:
|
||||
events = []
|
||||
for spec in event_specs:
|
||||
event_type = spec['type']
|
||||
# Attributes for the constructor
|
||||
kwargs = {
|
||||
k: v
|
||||
for k, v in spec.items()
|
||||
if k not in ['type', 'id', 'source', 'hidden', 'cause']
|
||||
}
|
||||
|
||||
# Provide default values for required fields if not in spec, to ensure instantiation
|
||||
if event_type == MessageAction and 'content' not in kwargs:
|
||||
kwargs['content'] = f'default_content_for_{spec["id"]}'
|
||||
elif event_type == CmdRunAction and 'command' not in kwargs:
|
||||
kwargs['command'] = f'default_command_for_{spec["id"]}'
|
||||
elif event_type == CmdOutputObservation:
|
||||
if 'content' not in kwargs:
|
||||
kwargs['content'] = f'default_obs_content_for_{spec["id"]}'
|
||||
if 'command_id' not in kwargs:
|
||||
kwargs['command_id'] = spec.get(
|
||||
'cause', spec['id'] - 1 if spec['id'] > 0 else 0
|
||||
) # Simplistic default
|
||||
if 'command' not in kwargs:
|
||||
kwargs['command'] = f'default_cmd_for_obs_{spec["id"]}'
|
||||
elif event_type == NullAction:
|
||||
assert 'content' not in kwargs
|
||||
elif event_type == NullObservation:
|
||||
kwargs['content'] = ''
|
||||
elif event_type == ChangeAgentStateAction:
|
||||
if 'agent_state' not in kwargs:
|
||||
kwargs['agent_state'] = 'running'
|
||||
if 'thought' not in kwargs:
|
||||
kwargs['thought'] = ''
|
||||
# 'content' for ChangeAgentStateAction is auto-generated by its message property
|
||||
elif event_type == AgentStateChangedObservation:
|
||||
if 'agent_state' not in kwargs:
|
||||
kwargs['agent_state'] = 'running'
|
||||
# 'content' for AgentStateChangedObservation is auto-generated by its message property
|
||||
|
||||
event = event_type(**kwargs)
|
||||
|
||||
# Set internal attributes after instantiation
|
||||
event._id = spec['id']
|
||||
# Default source based on type, can be overridden by spec
|
||||
default_source = (
|
||||
EventSource.AGENT if issubclass(event_type, Action) else EventSource.USER
|
||||
)
|
||||
event._source = spec.get('source', default_source)
|
||||
event._hidden = spec.get('hidden', False)
|
||||
if 'cause' in spec:
|
||||
event._cause = spec['cause']
|
||||
|
||||
events.append(event)
|
||||
return events
|
||||
|
||||
|
||||
def test_get_contextual_events_basic_retrieval():
|
||||
"""
|
||||
Tests basic retrieval of events, ensuring correct count, order, and string formatting.
|
||||
All events in this test are of types that are NOT filtered out by default.
|
||||
"""
|
||||
mock_event_stream = MagicMock(spec=EventStream)
|
||||
target_event_id = 5
|
||||
context_size = 4 # Hardcoded in _get_contextual_events
|
||||
|
||||
# Define all events that *could* be in the stream for this test
|
||||
all_event_specs = [
|
||||
{'id': 1, 'type': MessageAction, 'content': 'message_1'},
|
||||
{'id': 2, 'type': CmdRunAction, 'command': 'command_2'},
|
||||
{
|
||||
'id': 3,
|
||||
'type': CmdOutputObservation,
|
||||
'content': 'observation_3',
|
||||
'command_id': 2,
|
||||
'command': 'command_2',
|
||||
},
|
||||
{'id': 4, 'type': MessageAction, 'content': 'message_4'},
|
||||
{'id': 5, 'type': CmdRunAction, 'command': 'command_5_target'}, # Target Event
|
||||
{
|
||||
'id': 6,
|
||||
'type': CmdOutputObservation,
|
||||
'content': 'observation_6',
|
||||
'command_id': 5,
|
||||
'command': 'command_5_target',
|
||||
},
|
||||
{'id': 7, 'type': MessageAction, 'content': 'message_7'},
|
||||
{'id': 8, 'type': CmdRunAction, 'command': 'command_8'},
|
||||
{
|
||||
'id': 9,
|
||||
'type': CmdOutputObservation,
|
||||
'content': 'observation_9',
|
||||
'command_id': 8,
|
||||
'command': 'command_8',
|
||||
},
|
||||
{'id': 10, 'type': MessageAction, 'content': 'message_10'},
|
||||
{
|
||||
'id': 11,
|
||||
'type': CmdRunAction,
|
||||
'command': 'command_11',
|
||||
}, # Extra event, should not be included in 'after' due to limit
|
||||
]
|
||||
all_events_objects = create_test_events(all_event_specs)
|
||||
|
||||
# Map IDs to objects for easy lookup
|
||||
events_by_id = {e.id: e for e in all_events_objects}
|
||||
|
||||
# Define what search_events should return for the "before" call
|
||||
# (event_id=5, limit=4, reverse=True) -> expects [5, 4, 3, 2]
|
||||
events_to_return_before = [
|
||||
events_by_id[5],
|
||||
events_by_id[4],
|
||||
events_by_id[3],
|
||||
events_by_id[2],
|
||||
]
|
||||
|
||||
# Define what search_events should return for the "after" call
|
||||
# (start_id=6, limit=5) -> expects [6, 7, 8, 9, 10]
|
||||
events_to_return_after = [
|
||||
events_by_id[6],
|
||||
events_by_id[7],
|
||||
events_by_id[8],
|
||||
events_by_id[9],
|
||||
events_by_id[10],
|
||||
]
|
||||
|
||||
mock_event_stream.search_events.side_effect = [
|
||||
events_to_return_before,
|
||||
events_to_return_after,
|
||||
]
|
||||
|
||||
result_str = _get_contextual_events(mock_event_stream, target_event_id)
|
||||
|
||||
# Expected final list of events after processing (chronological order):
|
||||
# [event_obj_2, event_obj_3, event_obj_4, event_obj_5, (from before, reversed)
|
||||
# event_obj_6, event_obj_7, event_obj_8, event_obj_9, event_obj_10 (from after)]
|
||||
expected_final_event_objects = [
|
||||
events_by_id[2],
|
||||
events_by_id[3],
|
||||
events_by_id[4],
|
||||
events_by_id[5],
|
||||
events_by_id[6],
|
||||
events_by_id[7],
|
||||
events_by_id[8],
|
||||
events_by_id[9],
|
||||
events_by_id[10],
|
||||
]
|
||||
|
||||
# The output string is joined by newlines, using event.__str__
|
||||
expected_output_str = '\n'.join(str(e) for e in expected_final_event_objects)
|
||||
|
||||
assert result_str == expected_output_str
|
||||
|
||||
# Check calls to search_events
|
||||
calls = mock_event_stream.search_events.call_args_list
|
||||
assert len(calls) == 2
|
||||
|
||||
# Call 1: Before events
|
||||
args_before, kwargs_before = calls[0]
|
||||
assert kwargs_before['start_id'] == target_event_id
|
||||
assert isinstance(kwargs_before['filter'], EventFilter)
|
||||
assert kwargs_before['reverse'] is True
|
||||
assert kwargs_before['limit'] == context_size
|
||||
|
||||
# Call 2: After events
|
||||
args_after, kwargs_after = calls[1]
|
||||
assert kwargs_after['start_id'] == target_event_id + 1
|
||||
assert isinstance(kwargs_after['filter'], EventFilter)
|
||||
assert (
|
||||
'reverse' not in kwargs_after or kwargs_after['reverse'] is False
|
||||
) # default is False
|
||||
assert kwargs_after['limit'] == context_size + 1
|
||||
|
||||
|
||||
def test_get_contextual_events_filtering():
|
||||
"""
|
||||
Tests that specified event types and hidden events are filtered out.
|
||||
"""
|
||||
mock_event_stream = MagicMock(spec=EventStream)
|
||||
target_event_id = 3 # Target a non-filtered event
|
||||
|
||||
all_event_specs = [
|
||||
# Before target_event_id = 3. Context size 4. Search limit 4.
|
||||
# search_events(start_id=3, reverse=True, limit=4)
|
||||
{'id': 0, 'type': NullAction}, # Filtered
|
||||
{'id': 1, 'type': MessageAction, 'content': 'message_1_VISIBLE'}, # Visible
|
||||
{
|
||||
'id': 2,
|
||||
'type': ChangeAgentStateAction,
|
||||
'agent_state': 'thinking',
|
||||
'thought': 'abc_FILTERED',
|
||||
}, # Filtered
|
||||
{
|
||||
'id': 3,
|
||||
'type': CmdRunAction,
|
||||
'command': 'command_3_TARGET_VISIBLE',
|
||||
}, # Target, Visible
|
||||
# After target_event_id = 3. Context size 4 + 1 = 5. Search limit 5.
|
||||
# search_events(start_id=4, limit=5)
|
||||
{
|
||||
'id': 4,
|
||||
'type': CmdOutputObservation,
|
||||
'content': 'obs_4_HIDDEN_FILTERED',
|
||||
'command_id': 3,
|
||||
'hidden': True,
|
||||
}, # Filtered (hidden)
|
||||
{
|
||||
'id': 5,
|
||||
'type': AgentStateChangedObservation,
|
||||
'agent_state': 'running',
|
||||
'content': 'state_change_5_FILTERED',
|
||||
}, # Filtered
|
||||
{'id': 6, 'type': MessageAction, 'content': 'message_6_VISIBLE'}, # Visible
|
||||
{
|
||||
'id': 7,
|
||||
'type': NullObservation,
|
||||
'content': 'null_obs_7_FILTERED',
|
||||
}, # Filtered
|
||||
{'id': 8, 'type': CmdRunAction, 'command': 'command_8_VISIBLE'}, # Visible
|
||||
{
|
||||
'id': 9,
|
||||
'type': MessageAction,
|
||||
'content': 'message_9_VISIBLE',
|
||||
}, # Visible (within limit of 5 for 'after' search)
|
||||
{
|
||||
'id': 10,
|
||||
'type': MessageAction,
|
||||
'content': 'message_10_EXTRA',
|
||||
}, # Extra, should not be fetched by 'after' search
|
||||
]
|
||||
all_events_objects = create_test_events(all_event_specs)
|
||||
events_by_id = {e.id: e for e in all_events_objects}
|
||||
|
||||
# Expected events to be returned by search_events AFTER internal filtering by EventFilter
|
||||
# For "before" call (start_id=3, reverse=True, limit=4):
|
||||
# Raw available before/incl target: [cmd3, state2_filt, msg1, null0_filt]
|
||||
# After EventFilter: [cmd3, msg1] -> search_events should return these
|
||||
[events_by_id[3], events_by_id[1]]
|
||||
|
||||
# For "after" call (start_id=4, limit=5):
|
||||
# Raw available after target: [hidden_obs4_filt, agent_state5_filt, msg6, null_obs7_filt, cmd8, msg9, msg10_extra]
|
||||
# After EventFilter: [msg6, cmd8, msg9, msg10_extra] -> search_events should return first 5 of these if available
|
||||
# Limit is 5, so it should return [msg6, cmd8, msg9] (msg10_extra is out of original context_size+1 scope)
|
||||
# Correcting this: the mock search_events should simulate what EventStream.search_events does.
|
||||
# EventStream.search_events applies the filter internally.
|
||||
# So, the lists passed to side_effect should be the *already filtered* lists.
|
||||
|
||||
# Simulating EventStream.search_events behavior:
|
||||
# It iterates, applies filter, then takes limit.
|
||||
|
||||
# Before: start_id=3, reverse=True, limit=4. Candidates: [3,2,1,0]. Filtered: [3,1]. Result: [3,1]
|
||||
simulated_search_before = [events_by_id[3], events_by_id[1]]
|
||||
|
||||
# After: start_id=4, limit=5. Candidates: [4,5,6,7,8,9,10]. Filtered: [6,8,9]. Result: [6,8,9]
|
||||
simulated_search_after = [events_by_id[6], events_by_id[8], events_by_id[9]]
|
||||
|
||||
mock_event_stream.search_events.side_effect = [
|
||||
simulated_search_before,
|
||||
simulated_search_after,
|
||||
]
|
||||
|
||||
result_str = _get_contextual_events(mock_event_stream, target_event_id)
|
||||
|
||||
expected_final_event_objects = [
|
||||
events_by_id[1], # from before, reversed
|
||||
events_by_id[3], # from before, reversed (target)
|
||||
events_by_id[6], # from after
|
||||
events_by_id[8], # from after
|
||||
events_by_id[9], # from after
|
||||
]
|
||||
expected_output_str = '\n'.join(str(e) for e in expected_final_event_objects)
|
||||
|
||||
assert result_str == expected_output_str
|
||||
|
||||
# Verify the EventFilter used in search_events
|
||||
calls = mock_event_stream.search_events.call_args_list
|
||||
assert len(calls) == 2
|
||||
|
||||
expected_filtered_types = (
|
||||
NullAction,
|
||||
NullObservation,
|
||||
ChangeAgentStateAction,
|
||||
AgentStateChangedObservation,
|
||||
)
|
||||
|
||||
# Check filter for "before" call
|
||||
filter_before = calls[0][1]['filter'] # kwargs['filter']
|
||||
assert isinstance(filter_before, EventFilter)
|
||||
assert filter_before.exclude_hidden is True
|
||||
assert set(filter_before.exclude_types) == set(expected_filtered_types)
|
||||
|
||||
# Check filter for "after" call
|
||||
filter_after = calls[1][1]['filter'] # kwargs['filter']
|
||||
assert isinstance(filter_after, EventFilter)
|
||||
assert filter_after.exclude_hidden is True
|
||||
assert set(filter_after.exclude_types) == set(expected_filtered_types)
|
||||
|
||||
|
||||
def test_get_contextual_events_target_at_beginning():
|
||||
"""
|
||||
Tests behavior when the target event_id is at the beginning of the stream,
|
||||
resulting in fewer than context_size events before it.
|
||||
"""
|
||||
mock_event_stream = MagicMock(spec=EventStream)
|
||||
target_event_id = 1 # Target is the second event (IDs are 0-indexed in list, 1-indexed for events)
|
||||
context_size = 4
|
||||
|
||||
all_event_specs = [
|
||||
{'id': 0, 'type': MessageAction, 'content': 'message_0_first'},
|
||||
{'id': 1, 'type': CmdRunAction, 'command': 'command_1_TARGET'}, # Target
|
||||
{'id': 2, 'type': CmdOutputObservation, 'content': 'obs_2', 'command_id': 1},
|
||||
{'id': 3, 'type': MessageAction, 'content': 'message_3'},
|
||||
{'id': 4, 'type': CmdRunAction, 'command': 'command_4'},
|
||||
{'id': 5, 'type': CmdOutputObservation, 'content': 'obs_5', 'command_id': 4},
|
||||
{
|
||||
'id': 6,
|
||||
'type': MessageAction,
|
||||
'content': 'message_6',
|
||||
}, # Should be fetched by 'after'
|
||||
]
|
||||
all_events_objects = create_test_events(all_event_specs)
|
||||
events_by_id = {e.id: e for e in all_events_objects}
|
||||
|
||||
# Before: start_id=1, reverse=True, limit=4. Candidates: [1,0]. Filtered: [1,0]. Result: [1,0]
|
||||
simulated_search_before = [events_by_id[1], events_by_id[0]]
|
||||
|
||||
# After: start_id=2, limit=5. Candidates: [2,3,4,5,6]. Filtered: [2,3,4,5,6]. Result: [2,3,4,5,6]
|
||||
simulated_search_after = [
|
||||
events_by_id[2],
|
||||
events_by_id[3],
|
||||
events_by_id[4],
|
||||
events_by_id[5],
|
||||
events_by_id[6],
|
||||
]
|
||||
|
||||
mock_event_stream.search_events.side_effect = [
|
||||
simulated_search_before,
|
||||
simulated_search_after,
|
||||
]
|
||||
|
||||
result_str = _get_contextual_events(mock_event_stream, target_event_id)
|
||||
|
||||
# Expected final: [event_obj_0, event_obj_1] (from before, reversed)
|
||||
# + [event_obj_2, event_obj_3, event_obj_4, event_obj_5, event_obj_6] (from after)
|
||||
expected_final_event_objects = [
|
||||
events_by_id[0],
|
||||
events_by_id[1],
|
||||
events_by_id[2],
|
||||
events_by_id[3],
|
||||
events_by_id[4],
|
||||
events_by_id[5],
|
||||
events_by_id[6],
|
||||
]
|
||||
expected_output_str = '\n'.join(str(e) for e in expected_final_event_objects)
|
||||
|
||||
assert result_str == expected_output_str
|
||||
|
||||
calls = mock_event_stream.search_events.call_args_list
|
||||
assert len(calls) == 2
|
||||
# Call 1: Before events
|
||||
kwargs_before = calls[0][1]
|
||||
assert kwargs_before['start_id'] == target_event_id
|
||||
assert kwargs_before['limit'] == context_size
|
||||
# Call 2: After events
|
||||
kwargs_after = calls[1][1]
|
||||
assert kwargs_after['start_id'] == target_event_id + 1
|
||||
assert kwargs_after['limit'] == context_size + 1
|
||||
|
||||
|
||||
def test_get_contextual_events_target_at_end():
|
||||
"""
|
||||
Tests behavior when the target event_id is at the end of the stream,
|
||||
resulting in fewer than context_size + 1 events after it.
|
||||
"""
|
||||
mock_event_stream = MagicMock(spec=EventStream)
|
||||
target_event_id = 5 # Target is near the end
|
||||
context_size = 4
|
||||
|
||||
all_event_specs = [
|
||||
{'id': 0, 'type': MessageAction, 'content': 'message_0'},
|
||||
{'id': 1, 'type': CmdRunAction, 'command': 'command_1'},
|
||||
{
|
||||
'id': 2,
|
||||
'type': CmdOutputObservation,
|
||||
'content': 'obs_2',
|
||||
'command_id': 1,
|
||||
}, # Fetched by 'before'
|
||||
{'id': 3, 'type': MessageAction, 'content': 'message_3'}, # Fetched by 'before'
|
||||
{'id': 4, 'type': CmdRunAction, 'command': 'command_4'}, # Fetched by 'before'
|
||||
{
|
||||
'id': 5,
|
||||
'type': CmdOutputObservation,
|
||||
'content': 'obs_5_TARGET',
|
||||
'command_id': 4,
|
||||
}, # Target, Fetched by 'before'
|
||||
{
|
||||
'id': 6,
|
||||
'type': MessageAction,
|
||||
'content': 'message_6_last',
|
||||
}, # Last event, fetched by 'after'
|
||||
]
|
||||
all_events_objects = create_test_events(all_event_specs)
|
||||
events_by_id = {e.id: e for e in all_events_objects}
|
||||
|
||||
# Before: start_id=5, reverse=True, limit=4.
|
||||
# Candidates (reverse chronological from stream): [5,4,3,2,1,0]
|
||||
# search_events (after its internal filtering, assuming all visible) should return: [events_by_id[5], events_by_id[4], events_by_id[3], events_by_id[2]]
|
||||
simulated_search_before = [
|
||||
events_by_id[5],
|
||||
events_by_id[4],
|
||||
events_by_id[3],
|
||||
events_by_id[2],
|
||||
]
|
||||
|
||||
# After: start_id=6, limit=context_size + 1 = 5.
|
||||
# Candidates from stream: [events_by_id[6]]
|
||||
# search_events (after its internal filtering) should return: [events_by_id[6]]
|
||||
simulated_search_after = [events_by_id[6]]
|
||||
|
||||
mock_event_stream.search_events.side_effect = [
|
||||
simulated_search_before,
|
||||
simulated_search_after,
|
||||
]
|
||||
|
||||
result_str = _get_contextual_events(mock_event_stream, target_event_id)
|
||||
|
||||
# Expected final:
|
||||
# From 'before' (reversed): [event_obj_2, event_obj_3, event_obj_4, event_obj_5]
|
||||
# From 'after': [event_obj_6]
|
||||
expected_final_event_objects = [
|
||||
events_by_id[2],
|
||||
events_by_id[3],
|
||||
events_by_id[4],
|
||||
events_by_id[5],
|
||||
events_by_id[6],
|
||||
]
|
||||
expected_output_str = '\n'.join(str(e) for e in expected_final_event_objects)
|
||||
|
||||
assert result_str == expected_output_str
|
||||
|
||||
calls = mock_event_stream.search_events.call_args_list
|
||||
assert len(calls) == 2
|
||||
kwargs_before = calls[0][1]
|
||||
assert kwargs_before['start_id'] == target_event_id
|
||||
assert kwargs_before['limit'] == context_size # context_size for before
|
||||
kwargs_after = calls[1][1]
|
||||
assert kwargs_after['start_id'] == target_event_id + 1
|
||||
assert kwargs_after['limit'] == context_size + 1 # context_size + 1 for after
|
||||
|
||||
|
||||
def test_get_contextual_events_empty_search_results():
|
||||
"""
|
||||
Tests behavior when search_events returns empty lists for before and after.
|
||||
"""
|
||||
mock_event_stream = MagicMock(spec=EventStream)
|
||||
target_event_id = 10
|
||||
context_size = 4
|
||||
|
||||
# search_events will return empty lists
|
||||
simulated_search_before = []
|
||||
simulated_search_after = []
|
||||
|
||||
mock_event_stream.search_events.side_effect = [
|
||||
simulated_search_before,
|
||||
simulated_search_after,
|
||||
]
|
||||
|
||||
result_str = _get_contextual_events(mock_event_stream, target_event_id)
|
||||
|
||||
expected_output_str = '' # Empty string as no events are found
|
||||
|
||||
assert result_str == expected_output_str
|
||||
|
||||
calls = mock_event_stream.search_events.call_args_list
|
||||
assert len(calls) == 2
|
||||
kwargs_before = calls[0][1]
|
||||
assert kwargs_before['start_id'] == target_event_id
|
||||
assert kwargs_before['limit'] == context_size
|
||||
kwargs_after = calls[1][1]
|
||||
assert kwargs_after['start_id'] == target_event_id + 1
|
||||
assert kwargs_after['limit'] == context_size + 1
|
||||
|
||||
|
||||
def test_get_contextual_events_all_events_filtered():
|
||||
"""
|
||||
Tests behavior when all events in the context window are of types
|
||||
that should be filtered out.
|
||||
"""
|
||||
mock_event_stream = MagicMock(spec=EventStream)
|
||||
target_event_id = (
|
||||
2 # Target event itself might be filtered or not, doesn't matter for this test
|
||||
)
|
||||
|
||||
# All events are of types that should be filtered by the default filter in _get_contextual_events
|
||||
# create_test_events(all_event_specs) # Not strictly needed as search_events will return []
|
||||
|
||||
# search_events, after applying the internal EventFilter, will return empty lists
|
||||
simulated_search_before = []
|
||||
simulated_search_after = []
|
||||
|
||||
mock_event_stream.search_events.side_effect = [
|
||||
simulated_search_before,
|
||||
simulated_search_after,
|
||||
]
|
||||
|
||||
result_str = _get_contextual_events(mock_event_stream, target_event_id)
|
||||
|
||||
expected_output_str = '' # Empty string as all events are filtered
|
||||
|
||||
assert result_str == expected_output_str
|
||||
|
||||
calls = mock_event_stream.search_events.call_args_list
|
||||
assert len(calls) == 2 # Still called twice
|
||||
|
||||
# Check the filter properties on one of the calls (they should be identical)
|
||||
filter_used = calls[0][1]['filter']
|
||||
expected_filtered_types = (
|
||||
NullAction,
|
||||
NullObservation,
|
||||
ChangeAgentStateAction,
|
||||
AgentStateChangedObservation,
|
||||
)
|
||||
assert isinstance(filter_used, EventFilter)
|
||||
assert filter_used.exclude_hidden is True
|
||||
assert set(filter_used.exclude_types) == set(expected_filtered_types)
|
||||
Reference in New Issue
Block a user