Mirror of https://github.com/All-Hands-AI/OpenHands.git
Improve performance of LLM summarizing condenser (#6597)
Co-authored-by: Calvin Smith <calvin@all-hands.dev>
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
openhands/core/config/condenser_config.py:

@@ -43,6 +43,14 @@ class LLMSummarizingCondenserConfig(BaseModel):
     llm_config: LLMConfig = Field(
         ..., description='Configuration for the LLM to use for condensing.'
     )
+    keep_first: int = Field(
+        default=1,
+        description='The number of initial events to condense.',
+        ge=0,
+    )
+    max_size: int = Field(
+        default=10, description='Maximum number of events to keep.', ge=1
+    )


 class AmortizedForgettingCondenserConfig(BaseModel):
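For reference, constructing the extended config directly might look like the following sketch. The model name and API key are placeholder values, and keep_first must stay below half of max_size, a constraint the condenser constructor enforces in the next hunk:

from openhands.core.config.condenser_config import LLMSummarizingCondenserConfig
from openhands.core.config.llm_config import LLMConfig

# Placeholder model/key. keep_first must stay below max_size // 2, or the
# condenser built from this config (see below) raises ValueError.
config = LLMSummarizingCondenserConfig(
    llm_config=LLMConfig(model='gpt-4o', api_key='test_key'),
    keep_first=3,
    max_size=40,
)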
LLM summarizing condenser implementation:

@@ -1,55 +1,112 @@
 from __future__ import annotations

 from openhands.core.config.condenser_config import LLMSummarizingCondenserConfig
-from openhands.core.logger import openhands_logger as logger
 from openhands.events.event import Event
 from openhands.events.observation.agent import AgentCondensationObservation
 from openhands.llm import LLM
-from openhands.memory.condenser.condenser import Condenser
+from openhands.memory.condenser.condenser import RollingCondenser


-class LLMSummarizingCondenser(Condenser):
-    """A condenser that relies on a language model to summarize the event sequence as a single event."""
+class LLMSummarizingCondenser(RollingCondenser):
+    """A condenser that summarizes forgotten events.
+
+    Maintains a condensed history and forgets old events when it grows too large,
+    keeping a special summarization event after the prefix that summarizes all
+    previous summarizations and newly forgotten events.
+    """

-    def __init__(self, llm: LLM):
+    def __init__(self, llm: LLM, max_size: int = 100, keep_first: int = 1):
+        if keep_first >= max_size // 2:
+            raise ValueError(
+                f'keep_first ({keep_first}) must be less than half of max_size ({max_size})'
+            )
+        if keep_first < 0:
+            raise ValueError(f'keep_first ({keep_first}) cannot be negative')
+        if max_size < 1:
+            raise ValueError(f'max_size ({max_size}) cannot be non-positive')
+
+        self.max_size = max_size
+        self.keep_first = keep_first
         self.llm = llm

         super().__init__()

     def condense(self, events: list[Event]) -> list[Event]:
-        """Applies an LLM to summarize the list of events.
-
-        Raises:
-            Exception: If the LLM is unable to summarize the event sequence.
-        """
-        try:
-            # Convert events to a format suitable for summarization
-            events_text = '\n'.join(f'{e.timestamp}: {e.message}' for e in events)
-            summarize_prompt = f'Please summarize these events:\n{events_text}'
-
-            resp = self.llm.completion(
-                messages=[{'content': summarize_prompt, 'role': 'user'}]
-            )
-            summary_response = resp.choices[0].message.content
-
-            # Create a new summary event with the condensed content
-            summary_event = AgentCondensationObservation(summary_response)
-
-            # Add metrics to state
-            self.add_metadata('response', resp.model_dump())
-            self.add_metadata('metrics', self.llm.metrics.get())
-
-            return [summary_event]
-
-        except Exception as e:
-            logger.error(f'Error condensing events: {str(e)}')
-            raise e
+        """Apply the amortized forgetting strategy with LLM summarization to the given list of events."""
+        if len(events) <= self.max_size:
+            return events
+
+        head = events[: self.keep_first]
+
+        target_size = self.max_size // 2
+        events_from_tail = target_size - len(head)
+        tail = events[-events_from_tail:]
+
+        # Reuse the previous summary event if one sits right after the prefix
+        summary_event = (
+            events[self.keep_first]
+            if isinstance(events[self.keep_first], AgentCondensationObservation)
+            else AgentCondensationObservation('No events summarized')
+        )
+
+        # Identify events to be forgotten (those not in head or tail)
+        forgotten_events = []
+        for event in events[self.keep_first : -events_from_tail]:
+            if not isinstance(event, AgentCondensationObservation):
+                forgotten_events.append(event)
+
+        # Construct prompt for summarization
+        prompt = """You are maintaining state history for an LLM-based code agent. Track:
+
+STATE: {File paths, function signatures, data structures}
+TESTS: {Failing cases, error messages, outputs}
+CHANGES: {Code edits, variable updates}
+DEPS: {Dependencies, imports, external calls}
+INTENT: {Why changes were made, acceptance criteria}
+
+SKIP: {Git clones, build logs}
+SUMMARIZE: {File listings}
+MAX_LENGTH: Keep summaries under 1000 words
+
+Example history format:
+STATE: mod_float() in card.py updated
+TESTS: test_format() passed
+CHANGES: str(val) replaces f"{val:.16G}"
+DEPS: None modified
+INTENT: Fix float precision overflow"""
+
+        prompt += '\n\n'
+
+        prompt += ('\n' + summary_event.message + '\n') if summary_event.message else ''
+
+        prompt += '\n\n'
+
+        for forgotten_event in forgotten_events:
+            prompt += str(forgotten_event) + '\n\n'
+
+        response = self.llm.completion(
+            messages=[
+                {
+                    'content': prompt,
+                    'role': 'user',
+                },
+            ],
+        )
+        summary = response.choices[0].message.content
+
+        self.add_metadata('response', response.model_dump())
+        self.add_metadata('metrics', self.llm.metrics.get())
+
+        return head + [AgentCondensationObservation(summary)] + tail

     @classmethod
     def from_config(
         cls, config: LLMSummarizingCondenserConfig
     ) -> LLMSummarizingCondenser:
-        return LLMSummarizingCondenser(llm=LLM(config=config.llm_config))
+        return LLMSummarizingCondenser(
+            llm=LLM(config=config.llm_config),
+            max_size=config.max_size,
+            keep_first=config.keep_first,
+        )


 LLMSummarizingCondenser.register_config(LLMSummarizingCondenserConfig)
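The performance win appears to come from the early return: the old condense() summarized the entire history on every call, while the new one does nothing until the history exceeds max_size, then shrinks it to about max_size // 2 so the summarization cost is amortized over many agent steps. A minimal, self-contained sketch of the head/forgotten/tail split, with plain strings standing in for Event objects and the reuse of a prior summary event omitted:

def split_for_condensation(events, max_size=10, keep_first=2):
    # Mirrors the slicing in condense(); keep_first < max_size // 2 (enforced
    # by the constructor) guarantees events_from_tail >= 1.
    if len(events) <= max_size:
        return events, [], []
    head = events[:keep_first]                    # always-kept prefix
    target_size = max_size // 2                   # shrink to half the window
    events_from_tail = target_size - len(head)    # slots left for recent events
    tail = events[-events_from_tail:]
    forgotten = events[keep_first:-events_from_tail]  # summarized away
    return head, forgotten, tail


head, forgotten, tail = split_for_condensation([f'e{i}' for i in range(12)])
print(head)       # ['e0', 'e1']
print(forgotten)  # ['e2', 'e3', 'e4', 'e5', 'e6', 'e7', 'e8']
print(tail)       # ['e9', 'e10', 'e11']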
Server session setup:

@@ -6,7 +6,9 @@ import socketio

 from openhands.controller.agent import Agent
 from openhands.core.config import AppConfig
-from openhands.core.config.condenser_config import AmortizedForgettingCondenserConfig
+from openhands.core.config.condenser_config import (
+    LLMSummarizingCondenserConfig,
+)
 from openhands.core.const.guide_url import TROUBLESHOOTING_URL
 from openhands.core.logger import openhands_logger as logger
 from openhands.core.schema import AgentState

@@ -108,8 +110,8 @@ class Session:
         agent_config = self.config.get_agent_config(agent_cls)

         if settings.enable_default_condenser:
-            default_condenser_config = AmortizedForgettingCondenserConfig(
-                keep_first=3, max_size=20
+            default_condenser_config = LLMSummarizingCondenserConfig(
+                llm_config=llm.config, keep_first=3, max_size=40
             )
             logger.info(f'Enabling default condenser: {default_condenser_config}')
             agent_config.condenser = default_condenser_config
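Net effect for new sessions: the default condenser switches from pure amortized forgetting (max_size=20) to the LLM-backed summarizer with a doubled window (max_size=40), reusing the session's existing LLM config. A short sketch of how such a config becomes a condenser through the registry (Condenser.from_config is the entry point exercised by the tests below):

from openhands.memory.condenser import Condenser

# default_condenser_config as built in Session above; from_config dispatches on
# the config type that LLMSummarizingCondenser registered via register_config.
condenser = Condenser.from_config(default_condenser_config)
assert condenser.max_size == 40
assert condenser.keep_first == 3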
Condenser unit tests:

@@ -15,6 +15,7 @@ from openhands.core.config.condenser_config import (
 )
+from openhands.core.config.llm_config import LLMConfig
 from openhands.events.event import Event, EventSource
 from openhands.events.observation.agent import AgentCondensationObservation
 from openhands.events.observation.observation import Observation
 from openhands.llm import LLM
 from openhands.memory.condenser import Condenser
@@ -214,47 +215,117 @@ def test_recent_events_condenser():
     assert result[2]._message == 'Event 5'


-def test_llm_condenser_from_config():
-    """Test that LLMCondensers can be made from config."""
+def test_llm_summarization_condenser_from_config():
+    """Test that LLMSummarizingCondenser objects can be made from config."""
     config = LLMSummarizingCondenserConfig(
+        max_size=50,
+        keep_first=10,
         llm_config=LLMConfig(
             model='gpt-4o',
             api_key='test_key',
-        )
+        ),
     )
     condenser = Condenser.from_config(config)

     assert isinstance(condenser, LLMSummarizingCondenser)
     assert condenser.llm.config.model == 'gpt-4o'
     assert condenser.llm.config.api_key.get_secret_value() == 'test_key'
+    assert condenser.max_size == 50
+    assert condenser.keep_first == 10


-def test_llm_condenser(mock_llm, mock_state):
-    """Test that LLMCondensers use the LLM to generate a summary event."""
-    events = [
-        create_test_event('Event 1'),
-        create_test_event('Event 2'),
-    ]
-    mock_state.history = events
-
-    mock_llm.metrics = MagicMock()
+def test_llm_amortized_summarization_condenser_invalid_config():
+    """Test that LLMSummarizingCondenser raises error when keep_first > max_size."""
+    pytest.raises(
+        ValueError,
+        LLMSummarizingCondenser,
+        llm=MagicMock(),
+        max_size=4,
+        keep_first=2,
+    )
+    pytest.raises(ValueError, LLMSummarizingCondenser, llm=MagicMock(), max_size=0)
+    pytest.raises(ValueError, LLMSummarizingCondenser, llm=MagicMock(), keep_first=-1)
+
+
+def test_llm_summarizing_condenser_grows_to_max_size(mock_llm, mock_state):
+    """Test that LLMSummarizingCondenser correctly maintains an event context up to max size."""
+    max_size = 15
+    condenser = LLMSummarizingCondenser(max_size=max_size, llm=mock_llm)
+
+    for i in range(max_size):
+        event = create_test_event(f'Event {i}')
+        mock_state.history.append(event)
+        results = condenser.condensed_history(mock_state)
+        assert len(results) == i + 1
+
+
+def test_llm_summarizing_condenser_forgets_and_summarizes(mock_llm, mock_state):
+    """Test that the LLMSummarizingCondenser forgets events and maintains a summary."""
+    max_size = 4
+    keep_first = 1
+    condenser = LLMSummarizingCondenser(
+        max_size=max_size, keep_first=keep_first, llm=mock_llm
+    )
+
+    # Add initial event
+    first_event = create_test_event('Event 0')
+    mock_state.history.append(first_event)
+
+    # Set up mock LLM response
+    mock_llm.set_mock_response_content('Summary of forgotten events')
+
+    # Add enough events to trigger forgetting
+    for i in range(max_size + 3):  # +3 to ensure we're well past max_size
+        event = create_test_event(f'Event {i+1}')
+        mock_state.history.append(event)
+
+    # Get the condensed history
+    results = condenser.condensed_history(mock_state)
+
+    # We should have exactly 3 events:
+    # 1. First event (keep_first = 1)
+    # 2. Summary event
+    # 3. Most recent event
+    assert len(results) == 3, f'Expected 3 events, got {len(results)}: {results}'
+    assert (
+        results[0] == first_event
+    ), f'First event should be {first_event}, got {results[0]}'
+    assert isinstance(
+        results[1], AgentCondensationObservation
+    ), f'Second event should be a summary, got {results[1]}'
+    assert (
+        results[1].content == 'Summary of forgotten events'
+    ), f"Summary content should be 'Summary of forgotten events', got {results[1].content}"
+    assert results[2] == event, f'Last event should be {event}, got {results[2]}'
+
+
+def test_llm_summarizing_condenser_llm_call(mock_llm, mock_state):
+    """Test that the LLM is called correctly when forgetting events."""
+    max_size = 4
+    keep_first = 1
+    condenser = LLMSummarizingCondenser(
+        max_size=max_size, keep_first=keep_first, llm=mock_llm
+    )
+
+    # Add initial event
+    first_event = create_test_event('Event 0')
+    mock_state.history.append(first_event)
+
+    # Set up mock LLM response
+    mock_llm.set_mock_response_content('Summary of forgotten events')
     mock_llm.metrics.get.return_value = {'test_metric': 1.0}

-    mock_llm.set_mock_response_content('Summary of events')
-
-    condenser = LLMSummarizingCondenser(llm=mock_llm)
-    result = condenser.condensed_history(mock_state)
-
-    assert len(result) == 1
-    assert result[0].content == 'Summary of events'
+    # Add enough events to trigger forgetting
+    for i in range(max_size):
+        event = create_test_event(f'Event {i+1}')
+        mock_state.history.append(event)
+    condenser.condensed_history(mock_state)

-    # Verify LLM was called with correct prompt.
+    # Verify LLM was called with correct prompt
     mock_llm.completion.assert_called_once()
     call_args = mock_llm.completion.call_args[1]
     assert 'messages' in call_args
     assert len(call_args['messages']) == 1
     assert 'Event 1' in call_args['messages'][0]['content']
     assert 'Event 2' in call_args['messages'][0]['content']

     # Verify metrics were added to state
     assert 'condenser_meta' in mock_state.extra_data
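These tests lean on mock_llm and mock_state fixtures whose definitions sit outside this diff; in particular, set_mock_response_content is a test helper, not an LLM API. A hypothetical minimal stand-in for the mock_llm fixture, purely to show the shape the tests assume:

from unittest.mock import MagicMock

import pytest


@pytest.fixture
def mock_llm() -> MagicMock:
    # Assumed shape only; the real fixture lives in the test module.
    llm = MagicMock()

    def set_mock_response_content(content: str) -> None:
        # Make llm.completion(...).choices[0].message.content return `content`.
        choice = MagicMock()
        choice.message.content = content
        llm.completion.return_value.choices = [choice]

    llm.set_mock_response_content = set_mock_response_content
    return llm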
@@ -262,25 +333,6 @@ def test_llm_condenser(mock_llm, mock_state):
     assert mock_state.extra_data['condenser_meta'][0]['metrics'] == {'test_metric': 1.0}


-def test_llm_condenser_error():
-    """Test that LLM errors are propagated during condensation."""
-    events = [create_test_event('Event 1', datetime(2024, 1, 1, 10, 0))]
-
-    mock_state = MagicMock()
-    mock_state.history = events
-
-    mock_llm = MagicMock()
-    mock_llm.completion.side_effect = Exception('LLM error')
-
-    condenser = LLMSummarizingCondenser(llm=mock_llm)
-
-    try:
-        condenser.condensed_history(mock_state)
-        raise AssertionError('Expected exception was not raised.')
-    except Exception as e:
-        assert str(e) == 'LLM error'
-
-
 def test_amortized_forgetting_condenser_from_config():
     """Test that AmortizedForgettingCondenser objects can be made from config."""
     max_size = 50
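The removed test_llm_condenser_error depended on the old behavior of summarizing on every call, where a single event was enough to reach the LLM. If equivalent coverage were wanted for the new condenser, a sketch could look like this, assuming condensed_history propagates LLM exceptions and given enough events to exceed max_size:

import pytest


def test_llm_condenser_error_propagates(mock_llm, mock_state):
    # Exceed max_size so condense() actually calls the LLM.
    for i in range(6):
        mock_state.history.append(create_test_event(f'Event {i}'))
    mock_llm.completion.side_effect = Exception('LLM error')

    condenser = LLMSummarizingCondenser(llm=mock_llm, max_size=4, keep_first=1)
    with pytest.raises(Exception, match='LLM error'):
        condenser.condensed_history(mock_state)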