feat(backend): New "update microagent prompt" API (#8357)

Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
Co-authored-by: Engel Nyst <engel.nyst@gmail.com>
This commit is contained in:
sp.wack
2025-06-10 22:10:55 +04:00
committed by GitHub
parent 07862c32cb
commit dca9c7bdc6
7 changed files with 738 additions and 1 deletions

View File

@@ -11,6 +11,7 @@ import {
GetTrajectoryResponse,
GitChangeDiff,
GitChange,
GetMicroagentPromptResponse,
} from "./open-hands.types";
import { openHands } from "./open-hands-axios";
import { ApiSettings, PostApiSettings, Provider } from "#/types/settings";
@@ -393,6 +394,20 @@ class OpenHands {
return data;
}
static async getMicroagentPrompt(
conversationId: string,
eventId: number,
): Promise<string> {
const { data } = await openHands.get<GetMicroagentPromptResponse>(
`/api/conversations/${conversationId}/remember_prompt`,
{
params: { event_id: eventId },
},
);
return data.prompt;
}
}
export default OpenHands;

View File

@@ -102,3 +102,8 @@ export interface GitChangeDiff {
modified: string;
original: string;
}
/**
 * Response payload for GET /api/conversations/:id/remember_prompt.
 */
export interface GetMicroagentPromptResponse {
  /** Result indicator reported by the backend (e.g. "success"). */
  status: string;
  /** The generated "remember" prompt text. */
  prompt: string;
}

View File

@@ -0,0 +1,12 @@
import { useQuery } from "@tanstack/react-query";
import { useConversationId } from "../use-conversation-id";
import OpenHands from "#/api/open-hands";
/**
 * React Query hook that fetches the "remember" prompt for a given event
 * in the currently active conversation.
 */
export const useGetMicroagentPrompt = ({ eventId }: { eventId: number }) => {
  // Conversation id comes from the surrounding route/context.
  const { conversationId } = useConversationId();
  const queryKey = ["conversation", "remember_prompt", conversationId, eventId];
  return useQuery({
    queryKey,
    queryFn: () => OpenHands.getMicroagentPrompt(conversationId, eventId),
  });
};

View File

@@ -0,0 +1,21 @@
You are tasked with generating a prompt that will be used by another AI to update a special reference file. This file contains important information and learnings that are used to carry out certain tasks. The file can be extended over time to incorporate new knowledge and experiences.
You have been provided with a subset of new events that may require updates to the special file. These events are:
<events>
{{ events }}
</events>
Your task is to analyze these events and determine what updates, if any, should be made to the special file. Then, you need to generate a prompt that will instruct another AI to make these updates correctly and efficiently.
When creating your prompt, follow these guidelines:
1. Clearly specify which parts of the file need to be updated or if new sections should be added.
2. Provide context for why these updates are necessary based on the new events.
3. Be specific about the information that should be added or modified.
4. Maintain the existing structure and formatting of the file.
5. Ensure that the updates are consistent with the current content and don't contradict existing information.
Now, based on the new events provided, generate a prompt that will guide the AI in making the appropriate updates to the special file. Your prompt should be clear, specific, and actionable. Include your prompt within <update_prompt> tags.
<update_prompt>
</update_prompt>

View File

@@ -1,11 +1,26 @@
import itertools
import re
import os
import uuid
from datetime import datetime, timezone
from fastapi import APIRouter, Depends, status
from fastapi.responses import JSONResponse
from jinja2 import Environment, FileSystemLoader
from pydantic import BaseModel, Field
from openhands.events.event_filter import EventFilter
from openhands.events.stream import EventStream
from openhands.events.action import (
ChangeAgentStateAction,
NullAction,
)
from openhands.events.observation import (
NullObservation,
AgentStateChangedObservation,
)
from openhands.core.config.llm_config import LLMConfig
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import (
PROVIDER_TOKEN_TYPE,
@@ -16,12 +31,15 @@ from openhands.integrations.service_types import (
ProviderType,
SuggestedTask,
)
from openhands.llm.llm import LLM
from openhands.runtime import get_runtime_cls
from openhands.server.data_models.agent_loop_info import AgentLoopInfo
from openhands.server.data_models.conversation_info import ConversationInfo
from openhands.server.data_models.conversation_info_result_set import (
ConversationInfoResultSet,
)
from openhands.server.services.conversation_service import create_new_conversation
from openhands.server.session.conversation import ServerConversation
from openhands.server.dependencies import get_dependencies
from openhands.server.services.conversation_service import create_new_conversation
from openhands.server.shared import (
@@ -35,10 +53,11 @@ from openhands.server.user_auth import (
get_provider_tokens,
get_user_id,
get_user_secrets,
get_user_settings_store,
get_user_settings,
)
from openhands.server.user_auth.user_auth import AuthType
from openhands.server.utils import get_conversation_store
from openhands.server.utils import get_conversation_store, get_conversation as get_conversation_object
from openhands.storage.conversation.conversation_store import ConversationStore
from openhands.storage.data_models.conversation_metadata import (
ConversationMetadata,
@@ -47,6 +66,7 @@ from openhands.storage.data_models.conversation_metadata import (
from openhands.storage.data_models.conversation_status import ConversationStatus
from openhands.storage.data_models.settings import Settings
from openhands.storage.data_models.user_secrets import UserSecrets
from openhands.storage.settings.settings_store import SettingsStore
from openhands.utils.async_utils import wait_all
from openhands.utils.conversation_summary import get_default_conversation_title
@@ -271,6 +291,76 @@ async def delete_conversation(
return True
@app.get('/conversations/{conversation_id}/remember_prompt')
async def get_prompt(
    event_id: int,
    user_settings: SettingsStore = Depends(get_user_settings_store),
    conversation: ServerConversation | None = Depends(get_conversation_object),
):
    """Generate a "remember this" microagent prompt from events around `event_id`.

    Returns a JSON payload {'status': 'success', 'prompt': ...} on success,
    or {'error': ...} with an appropriate HTTP status code on failure.
    """
    if conversation is None:
        return JSONResponse(
            status_code=status.HTTP_404_NOT_FOUND,
            content={'error': 'Conversation not found.'},
        )
    # get event stream for the conversation
    event_stream = conversation.event_stream
    # retrieve the relevant events around the target event
    stringified_events = _get_contextual_events(event_stream, event_id)
    # the user's settings determine which LLM generates the prompt
    settings = await user_settings.load()
    if settings is None:
        # Return a proper HTTP error instead of raising an unhandled exception.
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={'error': 'Settings not found.'},
        )
    llm_config = LLMConfig(
        model=settings.llm_model,
        api_key=settings.llm_api_key,
        base_url=settings.llm_base_url,
    )
    prompt_template = generate_prompt_template(stringified_events)
    try:
        prompt = generate_prompt(llm_config, prompt_template)
    except ValueError as e:
        # generate_prompt raises ValueError when the LLM response contains
        # no <update_prompt> block.
        logger.error(f'Failed to generate remember prompt: {e}')
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={'error': 'Failed to generate prompt.'},
        )
    return JSONResponse(
        {
            'status': 'success',
            'prompt': prompt,
        }
    )
def generate_prompt_template(events: str) -> str:
    """Render the remember-prompt jinja2 template with the given events text."""
    # Templates live alongside the microagent package on disk.
    loader = FileSystemLoader('openhands/microagent/prompts')
    jinja_env = Environment(loader=loader)
    remember_template = jinja_env.get_template('generate_remember_prompt.j2')
    return remember_template.render(events=events)
def generate_prompt(llm_config: LLMConfig, prompt_template: str) -> str:
    """Ask the configured LLM for an update prompt and extract it.

    The LLM is expected to wrap its answer in <update_prompt> tags; the
    tag contents are returned stripped of surrounding whitespace.

    Raises:
        ValueError: if the response contains no <update_prompt> block.
    """
    conversation_llm = LLM(llm_config)
    completion = conversation_llm.completion(
        messages=[
            {
                'role': 'system',
                'content': prompt_template,
            },
            {
                'role': 'user',
                'content': 'Please generate a prompt for the AI to update the special file based on the events provided.',
            },
        ]
    )
    content = completion['choices'][0]['message']['content'].strip()
    # DOTALL so the prompt may span multiple lines inside the tags.
    match = re.search(r'<update_prompt>(.*?)</update_prompt>', content, re.DOTALL)
    if match is None:
        raise ValueError('No valid prompt found in the response.')
    return match.group(1).strip()
async def _get_conversation_info(
conversation: ConversationMetadata,
num_connections: int,
@@ -411,3 +501,38 @@ async def stop_conversation(
},
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
def _get_contextual_events(event_stream: EventStream, event_id: int) -> str:
    """Collect a window of events around `event_id` and render them as text.

    Fetches up to `context_size` events ending at the target (inclusive)
    and up to `context_size + 1` events after it, then joins their string
    representations with newlines in chronological order.
    """
    context_size = 4
    # Exclude hidden events and the event types that never appear in an
    # agent's history.
    history_filter = EventFilter(
        exclude_hidden=True,
        exclude_types=(
            NullAction,
            NullObservation,
            ChangeAgentStateAction,
            AgentStateChangedObservation,
        ),
    )
    # Events up to and including the target, returned newest-first.
    preceding = event_stream.search_events(
        start_id=event_id,
        filter=history_filter,
        reverse=True,
        limit=context_size,
    )
    # Events strictly after the target.
    following = event_stream.search_events(
        start_id=event_id + 1,
        filter=history_filter,
        limit=context_size + 1,
    )
    # `preceding` is newest-first; flip it back into chronological order
    # before appending the trailing context.
    chronological = list(preceding)[::-1] + list(following)
    return '\n'.join(str(event) for event in chronological)

View File

@@ -1,3 +1,5 @@
import uuid
from fastapi import Depends, HTTPException, Request, status
from openhands.core.logger import openhands_logger as logger
@@ -18,6 +20,15 @@ async def get_conversation_store(request: Request) -> ConversationStore | None:
return conversation_store
async def generate_unique_conversation_id(
    conversation_store: ConversationStore,
) -> str:
    """Generate a random hex conversation id not already present in the store."""
    while True:
        candidate = uuid.uuid4().hex
        # Retry until the store reports the id as unused.
        if not await conversation_store.exists(candidate):
            return candidate
async def get_conversation(
conversation_id: str, user_id: str | None = Depends(get_user_id)
):

View File

@@ -0,0 +1,548 @@
from unittest.mock import MagicMock
from openhands.events.action import (
Action,
ChangeAgentStateAction,
CmdRunAction,
MessageAction,
NullAction,
)
from openhands.events.event import Event, EventSource
from openhands.events.event_filter import (
EventFilter, # Needed for ANY matcher type check
)
from openhands.events.observation import (
AgentStateChangedObservation,
CmdOutputObservation,
NullObservation,
)
from openhands.events.stream import EventStream
from openhands.server.routes.manage_conversations import _get_contextual_events
# Helper to create event instances for testing, inspired by test_agent_history.py
def create_test_events(event_specs: list[dict]) -> list[Event]:
    """Build concrete Event objects from lightweight spec dicts.

    Each spec carries a 'type' key (the event class) plus optional 'id',
    'source', 'hidden' and 'cause' metadata; every remaining key is passed
    to the event constructor as a keyword argument. Missing required
    constructor fields are filled with deterministic defaults so any spec
    can be instantiated.
    """
    events = []
    for spec in event_specs:
        event_type = spec['type']
        # Attributes for the constructor
        kwargs = {
            k: v
            for k, v in spec.items()
            if k not in ['type', 'id', 'source', 'hidden', 'cause']
        }
        # Provide default values for required fields if not in spec, to ensure instantiation
        if event_type == MessageAction and 'content' not in kwargs:
            kwargs['content'] = f'default_content_for_{spec["id"]}'
        elif event_type == CmdRunAction and 'command' not in kwargs:
            kwargs['command'] = f'default_command_for_{spec["id"]}'
        elif event_type == CmdOutputObservation:
            if 'content' not in kwargs:
                kwargs['content'] = f'default_obs_content_for_{spec["id"]}'
            if 'command_id' not in kwargs:
                # Fall back to the 'cause' id, then to the previous event id.
                kwargs['command_id'] = spec.get(
                    'cause', spec['id'] - 1 if spec['id'] > 0 else 0
                )  # Simplistic default
            if 'command' not in kwargs:
                kwargs['command'] = f'default_cmd_for_obs_{spec["id"]}'
        elif event_type == NullAction:
            assert 'content' not in kwargs
        elif event_type == NullObservation:
            kwargs['content'] = ''
        elif event_type == ChangeAgentStateAction:
            if 'agent_state' not in kwargs:
                kwargs['agent_state'] = 'running'
            if 'thought' not in kwargs:
                kwargs['thought'] = ''
            # 'content' for ChangeAgentStateAction is auto-generated by its message property
        elif event_type == AgentStateChangedObservation:
            if 'agent_state' not in kwargs:
                kwargs['agent_state'] = 'running'
            # 'content' for AgentStateChangedObservation is auto-generated by its message property
        event = event_type(**kwargs)
        # Set internal attributes after instantiation
        event._id = spec['id']
        # Default source based on type, can be overridden by spec
        default_source = (
            EventSource.AGENT if issubclass(event_type, Action) else EventSource.USER
        )
        event._source = spec.get('source', default_source)
        event._hidden = spec.get('hidden', False)
        if 'cause' in spec:
            event._cause = spec['cause']
        events.append(event)
    return events
def test_get_contextual_events_basic_retrieval():
    """
    Tests basic retrieval of events, ensuring correct count, order, and string formatting.
    All events in this test are of types that are NOT filtered out by default.
    """
    mock_event_stream = MagicMock(spec=EventStream)
    target_event_id = 5
    context_size = 4  # Hardcoded in _get_contextual_events
    # Define all events that *could* be in the stream for this test
    all_event_specs = [
        {'id': 1, 'type': MessageAction, 'content': 'message_1'},
        {'id': 2, 'type': CmdRunAction, 'command': 'command_2'},
        {
            'id': 3,
            'type': CmdOutputObservation,
            'content': 'observation_3',
            'command_id': 2,
            'command': 'command_2',
        },
        {'id': 4, 'type': MessageAction, 'content': 'message_4'},
        {'id': 5, 'type': CmdRunAction, 'command': 'command_5_target'},  # Target Event
        {
            'id': 6,
            'type': CmdOutputObservation,
            'content': 'observation_6',
            'command_id': 5,
            'command': 'command_5_target',
        },
        {'id': 7, 'type': MessageAction, 'content': 'message_7'},
        {'id': 8, 'type': CmdRunAction, 'command': 'command_8'},
        {
            'id': 9,
            'type': CmdOutputObservation,
            'content': 'observation_9',
            'command_id': 8,
            'command': 'command_8',
        },
        {'id': 10, 'type': MessageAction, 'content': 'message_10'},
        {
            'id': 11,
            'type': CmdRunAction,
            'command': 'command_11',
        },  # Extra event, should not be included in 'after' due to limit
    ]
    all_events_objects = create_test_events(all_event_specs)
    # Map IDs to objects for easy lookup
    events_by_id = {e.id: e for e in all_events_objects}
    # Define what search_events should return for the "before" call
    # (event_id=5, limit=4, reverse=True) -> expects [5, 4, 3, 2]
    events_to_return_before = [
        events_by_id[5],
        events_by_id[4],
        events_by_id[3],
        events_by_id[2],
    ]
    # Define what search_events should return for the "after" call
    # (start_id=6, limit=5) -> expects [6, 7, 8, 9, 10]
    events_to_return_after = [
        events_by_id[6],
        events_by_id[7],
        events_by_id[8],
        events_by_id[9],
        events_by_id[10],
    ]
    # side_effect queues the two return values for the two search_events calls.
    mock_event_stream.search_events.side_effect = [
        events_to_return_before,
        events_to_return_after,
    ]
    result_str = _get_contextual_events(mock_event_stream, target_event_id)
    # Expected final list of events after processing (chronological order):
    # [event_obj_2, event_obj_3, event_obj_4, event_obj_5, (from before, reversed)
    # event_obj_6, event_obj_7, event_obj_8, event_obj_9, event_obj_10 (from after)]
    expected_final_event_objects = [
        events_by_id[2],
        events_by_id[3],
        events_by_id[4],
        events_by_id[5],
        events_by_id[6],
        events_by_id[7],
        events_by_id[8],
        events_by_id[9],
        events_by_id[10],
    ]
    # The output string is joined by newlines, using event.__str__
    expected_output_str = '\n'.join(str(e) for e in expected_final_event_objects)
    assert result_str == expected_output_str
    # Check calls to search_events
    calls = mock_event_stream.search_events.call_args_list
    assert len(calls) == 2
    # Call 1: Before events
    args_before, kwargs_before = calls[0]
    assert kwargs_before['start_id'] == target_event_id
    assert isinstance(kwargs_before['filter'], EventFilter)
    assert kwargs_before['reverse'] is True
    assert kwargs_before['limit'] == context_size
    # Call 2: After events
    args_after, kwargs_after = calls[1]
    assert kwargs_after['start_id'] == target_event_id + 1
    assert isinstance(kwargs_after['filter'], EventFilter)
    assert (
        'reverse' not in kwargs_after or kwargs_after['reverse'] is False
    )  # default is False
    assert kwargs_after['limit'] == context_size + 1
def test_get_contextual_events_filtering():
    """
    Tests that specified event types and hidden events are filtered out.

    EventStream.search_events applies the EventFilter internally, so the
    mocked search results here are the already-filtered lists.
    """
    mock_event_stream = MagicMock(spec=EventStream)
    target_event_id = 3  # Target a non-filtered event
    all_event_specs = [
        # Before target_event_id = 3. Context size 4. Search limit 4.
        # search_events(start_id=3, reverse=True, limit=4)
        {'id': 0, 'type': NullAction},  # Filtered
        {'id': 1, 'type': MessageAction, 'content': 'message_1_VISIBLE'},  # Visible
        {
            'id': 2,
            'type': ChangeAgentStateAction,
            'agent_state': 'thinking',
            'thought': 'abc_FILTERED',
        },  # Filtered
        {
            'id': 3,
            'type': CmdRunAction,
            'command': 'command_3_TARGET_VISIBLE',
        },  # Target, Visible
        # After target_event_id = 3. Context size 4 + 1 = 5. Search limit 5.
        # search_events(start_id=4, limit=5)
        {
            'id': 4,
            'type': CmdOutputObservation,
            'content': 'obs_4_HIDDEN_FILTERED',
            'command_id': 3,
            'hidden': True,
        },  # Filtered (hidden)
        {
            'id': 5,
            'type': AgentStateChangedObservation,
            'agent_state': 'running',
            'content': 'state_change_5_FILTERED',
        },  # Filtered
        {'id': 6, 'type': MessageAction, 'content': 'message_6_VISIBLE'},  # Visible
        {
            'id': 7,
            'type': NullObservation,
            'content': 'null_obs_7_FILTERED',
        },  # Filtered
        {'id': 8, 'type': CmdRunAction, 'command': 'command_8_VISIBLE'},  # Visible
        {
            'id': 9,
            'type': MessageAction,
            'content': 'message_9_VISIBLE',
        },  # Visible (within limit of 5 for 'after' search)
        {
            'id': 10,
            'type': MessageAction,
            'content': 'message_10_EXTRA',
        },  # Extra, should not be fetched by 'after' search
    ]
    all_events_objects = create_test_events(all_event_specs)
    events_by_id = {e.id: e for e in all_events_objects}
    # Simulating EventStream.search_events behavior: it iterates, applies the
    # filter, then takes the limit — so the mocked results are pre-filtered.
    # Before: start_id=3, reverse=True, limit=4. Candidates: [3,2,1,0]. Filtered: [3,1]. Result: [3,1]
    simulated_search_before = [events_by_id[3], events_by_id[1]]
    # After: start_id=4, limit=5. Candidates: [4,5,6,7,8,9,10]. Filtered: [6,8,9]. Result: [6,8,9]
    simulated_search_after = [events_by_id[6], events_by_id[8], events_by_id[9]]
    mock_event_stream.search_events.side_effect = [
        simulated_search_before,
        simulated_search_after,
    ]
    result_str = _get_contextual_events(mock_event_stream, target_event_id)
    expected_final_event_objects = [
        events_by_id[1],  # from before, reversed
        events_by_id[3],  # from before, reversed (target)
        events_by_id[6],  # from after
        events_by_id[8],  # from after
        events_by_id[9],  # from after
    ]
    expected_output_str = '\n'.join(str(e) for e in expected_final_event_objects)
    assert result_str == expected_output_str
    # Verify the EventFilter used in search_events
    calls = mock_event_stream.search_events.call_args_list
    assert len(calls) == 2
    expected_filtered_types = (
        NullAction,
        NullObservation,
        ChangeAgentStateAction,
        AgentStateChangedObservation,
    )
    # Check filter for "before" call
    filter_before = calls[0][1]['filter']  # kwargs['filter']
    assert isinstance(filter_before, EventFilter)
    assert filter_before.exclude_hidden is True
    assert set(filter_before.exclude_types) == set(expected_filtered_types)
    # Check filter for "after" call
    filter_after = calls[1][1]['filter']  # kwargs['filter']
    assert isinstance(filter_after, EventFilter)
    assert filter_after.exclude_hidden is True
    assert set(filter_after.exclude_types) == set(expected_filtered_types)
def test_get_contextual_events_target_at_beginning():
    """
    Tests behavior when the target event_id is at the beginning of the stream,
    resulting in fewer than context_size events before it.
    """
    mock_event_stream = MagicMock(spec=EventStream)
    target_event_id = 1  # Target is the second event (IDs are 0-indexed in list, 1-indexed for events)
    context_size = 4
    all_event_specs = [
        {'id': 0, 'type': MessageAction, 'content': 'message_0_first'},
        {'id': 1, 'type': CmdRunAction, 'command': 'command_1_TARGET'},  # Target
        {'id': 2, 'type': CmdOutputObservation, 'content': 'obs_2', 'command_id': 1},
        {'id': 3, 'type': MessageAction, 'content': 'message_3'},
        {'id': 4, 'type': CmdRunAction, 'command': 'command_4'},
        {'id': 5, 'type': CmdOutputObservation, 'content': 'obs_5', 'command_id': 4},
        {
            'id': 6,
            'type': MessageAction,
            'content': 'message_6',
        },  # Should be fetched by 'after'
    ]
    all_events_objects = create_test_events(all_event_specs)
    events_by_id = {e.id: e for e in all_events_objects}
    # Only two events exist at or before the target, so the "before" search
    # returns fewer than context_size events.
    # Before: start_id=1, reverse=True, limit=4. Candidates: [1,0]. Filtered: [1,0]. Result: [1,0]
    simulated_search_before = [events_by_id[1], events_by_id[0]]
    # After: start_id=2, limit=5. Candidates: [2,3,4,5,6]. Filtered: [2,3,4,5,6]. Result: [2,3,4,5,6]
    simulated_search_after = [
        events_by_id[2],
        events_by_id[3],
        events_by_id[4],
        events_by_id[5],
        events_by_id[6],
    ]
    mock_event_stream.search_events.side_effect = [
        simulated_search_before,
        simulated_search_after,
    ]
    result_str = _get_contextual_events(mock_event_stream, target_event_id)
    # Expected final: [event_obj_0, event_obj_1] (from before, reversed)
    # + [event_obj_2, event_obj_3, event_obj_4, event_obj_5, event_obj_6] (from after)
    expected_final_event_objects = [
        events_by_id[0],
        events_by_id[1],
        events_by_id[2],
        events_by_id[3],
        events_by_id[4],
        events_by_id[5],
        events_by_id[6],
    ]
    expected_output_str = '\n'.join(str(e) for e in expected_final_event_objects)
    assert result_str == expected_output_str
    calls = mock_event_stream.search_events.call_args_list
    assert len(calls) == 2
    # Call 1: Before events
    kwargs_before = calls[0][1]
    assert kwargs_before['start_id'] == target_event_id
    assert kwargs_before['limit'] == context_size
    # Call 2: After events
    kwargs_after = calls[1][1]
    assert kwargs_after['start_id'] == target_event_id + 1
    assert kwargs_after['limit'] == context_size + 1
def test_get_contextual_events_target_at_end():
    """
    Tests behavior when the target event_id is at the end of the stream,
    resulting in fewer than context_size + 1 events after it.
    """
    mock_event_stream = MagicMock(spec=EventStream)
    target_event_id = 5  # Target is near the end
    context_size = 4
    all_event_specs = [
        {'id': 0, 'type': MessageAction, 'content': 'message_0'},
        {'id': 1, 'type': CmdRunAction, 'command': 'command_1'},
        {
            'id': 2,
            'type': CmdOutputObservation,
            'content': 'obs_2',
            'command_id': 1,
        },  # Fetched by 'before'
        {'id': 3, 'type': MessageAction, 'content': 'message_3'},  # Fetched by 'before'
        {'id': 4, 'type': CmdRunAction, 'command': 'command_4'},  # Fetched by 'before'
        {
            'id': 5,
            'type': CmdOutputObservation,
            'content': 'obs_5_TARGET',
            'command_id': 4,
        },  # Target, Fetched by 'before'
        {
            'id': 6,
            'type': MessageAction,
            'content': 'message_6_last',
        },  # Last event, fetched by 'after'
    ]
    all_events_objects = create_test_events(all_event_specs)
    events_by_id = {e.id: e for e in all_events_objects}
    # Before: start_id=5, reverse=True, limit=4.
    # Candidates (reverse chronological from stream): [5,4,3,2,1,0]
    # search_events (after its internal filtering, assuming all visible) should return: [events_by_id[5], events_by_id[4], events_by_id[3], events_by_id[2]]
    simulated_search_before = [
        events_by_id[5],
        events_by_id[4],
        events_by_id[3],
        events_by_id[2],
    ]
    # Only one event exists after the target, so the "after" search returns
    # fewer than context_size + 1 events.
    # After: start_id=6, limit=context_size + 1 = 5.
    # Candidates from stream: [events_by_id[6]]
    # search_events (after its internal filtering) should return: [events_by_id[6]]
    simulated_search_after = [events_by_id[6]]
    mock_event_stream.search_events.side_effect = [
        simulated_search_before,
        simulated_search_after,
    ]
    result_str = _get_contextual_events(mock_event_stream, target_event_id)
    # Expected final:
    # From 'before' (reversed): [event_obj_2, event_obj_3, event_obj_4, event_obj_5]
    # From 'after': [event_obj_6]
    expected_final_event_objects = [
        events_by_id[2],
        events_by_id[3],
        events_by_id[4],
        events_by_id[5],
        events_by_id[6],
    ]
    expected_output_str = '\n'.join(str(e) for e in expected_final_event_objects)
    assert result_str == expected_output_str
    calls = mock_event_stream.search_events.call_args_list
    assert len(calls) == 2
    kwargs_before = calls[0][1]
    assert kwargs_before['start_id'] == target_event_id
    assert kwargs_before['limit'] == context_size  # context_size for before
    kwargs_after = calls[1][1]
    assert kwargs_after['start_id'] == target_event_id + 1
    assert kwargs_after['limit'] == context_size + 1  # context_size + 1 for after
def test_get_contextual_events_empty_search_results():
    """
    Tests behavior when search_events returns empty lists for before and after.
    """
    stream = MagicMock(spec=EventStream)
    target_id = 10
    window = 4  # matches the context_size hardcoded in _get_contextual_events
    # Both the "before" and "after" searches come back empty.
    stream.search_events.side_effect = [[], []]
    # With no events at all, the rendered context is the empty string.
    assert _get_contextual_events(stream, target_id) == ''
    calls = stream.search_events.call_args_list
    assert len(calls) == 2
    before_kwargs = calls[0][1]
    assert before_kwargs['start_id'] == target_id
    assert before_kwargs['limit'] == window
    after_kwargs = calls[1][1]
    assert after_kwargs['start_id'] == target_id + 1
    assert after_kwargs['limit'] == window + 1
def test_get_contextual_events_all_events_filtered():
    """
    Tests behavior when all events in the context window are of types
    that should be filtered out.
    """
    stream = MagicMock(spec=EventStream)
    # Whether the target event itself would be filtered does not matter here.
    target_id = 2
    # search_events applies the EventFilter internally; with every candidate
    # filtered out, both searches yield empty lists.
    stream.search_events.side_effect = [[], []]
    # All events filtered -> empty rendered context.
    assert _get_contextual_events(stream, target_id) == ''
    calls = stream.search_events.call_args_list
    assert len(calls) == 2  # both searches still happen
    # The two calls share an identical filter configuration; inspect the first.
    used_filter = calls[0][1]['filter']
    assert isinstance(used_filter, EventFilter)
    assert used_filter.exclude_hidden is True
    assert set(used_filter.exclude_types) == {
        NullAction,
        NullObservation,
        ChangeAgentStateAction,
        AgentStateChangedObservation,
    }