mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
1 Commits
fix-git-ch
...
openhands-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
96360d3e7b |
@@ -1,23 +0,0 @@
|
||||
# Memory Component
|
||||
|
||||
- Short Term History
|
||||
- Memory Condenser
|
||||
- Long Term Memory
|
||||
|
||||
## Short Term History
|
||||
- Short term history filters the event stream and computes the messages that are injected into the context
|
||||
- It filters out certain events of no interest for the Agent, such as AgentChangeStateObservation or NullAction/NullObservation
|
||||
- When the context window or the token limit set by the user is exceeded, history starts condensing: chunks of messages into summaries.
|
||||
- Each summary is then injected into the context, in the place of the respective chunk it summarizes
|
||||
|
||||
## Memory Condenser
|
||||
- Memory condenser is responsible for summarizing the chunks of events
|
||||
- It summarizes the earlier events first
|
||||
- It starts with the earliest agent actions and observations between two user messages
|
||||
- Then it does the same for later chunks of events between user messages
|
||||
- If there are no more agent events, it summarizes the user messages, this time one by one, if they're large enough and not immediately after an AgentFinishAction event (we assume those are tasks, potentially important)
|
||||
- Summaries are retrieved from the LLM as AgentSummarizeAction, and are saved in State.
|
||||
|
||||
## Long Term Memory
|
||||
- Long term memory component stores embeddings for events and prompts in a vector store
|
||||
- The agent can query it when it needs detailed information about a past event or to learn new actions
|
||||
@@ -1,4 +0,0 @@
|
||||
from openhands.memory.condenser import Condenser
|
||||
from openhands.memory.long_term_memory import LongTermMemory
|
||||
|
||||
__all__ = ['LongTermMemory', 'Condenser']
|
||||
@@ -1,4 +0,0 @@
|
||||
import openhands.memory.condenser.impl # noqa F401 (we import this to get the condensers registered)
|
||||
from openhands.memory.condenser.condenser import Condenser, get_condensation_metadata
|
||||
|
||||
__all__ = ['Condenser', 'get_condensation_metadata', 'CONDENSER_REGISTRY']
|
||||
@@ -1,178 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config.condenser_config import CondenserConfig
|
||||
from openhands.events.event import Event
|
||||
|
||||
CONDENSER_METADATA_KEY = 'condenser_meta'
|
||||
"""Key identifying where metadata is stored in a `State` object's `extra_data` field."""
|
||||
|
||||
|
||||
def get_condensation_metadata(state: State) -> list[dict[str, Any]]:
|
||||
"""Utility function to retrieve a list of metadata batches from a `State`.
|
||||
|
||||
Args:
|
||||
state: The state to retrieve metadata from.
|
||||
|
||||
Returns:
|
||||
list[dict[str, Any]]: A list of metadata batches, each representing a condensation.
|
||||
"""
|
||||
if CONDENSER_METADATA_KEY in state.extra_data:
|
||||
return state.extra_data[CONDENSER_METADATA_KEY]
|
||||
return []
|
||||
|
||||
|
||||
CONDENSER_REGISTRY: dict[type[CondenserConfig], type[Condenser]] = {}
|
||||
"""Registry of condenser configurations to their corresponding condenser classes."""
|
||||
|
||||
|
||||
class Condenser(ABC):
|
||||
"""Abstract condenser interface.
|
||||
|
||||
Condensers take a list of `Event` objects and reduce them into a potentially smaller list.
|
||||
|
||||
Agents can use condensers to reduce the amount of events they need to consider when deciding which action to take. To use a condenser, agents can call the `condensed_history` method on the current `State` being considered and use the results instead of the full history.
|
||||
|
||||
Example usage::
|
||||
|
||||
condenser = Condenser.from_config(condenser_config)
|
||||
events = condenser.condensed_history(state)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._metadata_batch: dict[str, Any] = {}
|
||||
|
||||
def add_metadata(self, key: str, value: Any) -> None:
|
||||
"""Add information to the current metadata batch.
|
||||
|
||||
Any key/value pairs added to the metadata batch will be recorded in the `State` at the end of the current condensation.
|
||||
|
||||
Args:
|
||||
key: The key to store the metadata under.
|
||||
|
||||
value: The metadata to store.
|
||||
"""
|
||||
self._metadata_batch[key] = value
|
||||
|
||||
def write_metadata(self, state: State) -> None:
|
||||
"""Write the current batch of metadata to the `State`.
|
||||
|
||||
Resets the current metadata batch: any metadata added after this call will be stored in a new batch and written to the `State` at the end of the next condensation.
|
||||
"""
|
||||
if CONDENSER_METADATA_KEY not in state.extra_data:
|
||||
state.extra_data[CONDENSER_METADATA_KEY] = []
|
||||
if self._metadata_batch:
|
||||
state.extra_data[CONDENSER_METADATA_KEY].append(self._metadata_batch)
|
||||
|
||||
# Since the batch has been written, clear it for the next condensation
|
||||
self._metadata_batch = {}
|
||||
|
||||
@contextmanager
|
||||
def metadata_batch(self, state: State):
|
||||
"""Context manager to ensure batched metadata is always written to the `State`."""
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.write_metadata(state)
|
||||
|
||||
@abstractmethod
|
||||
def condense(self, events: list[Event]) -> list[Event]:
|
||||
"""Condense a sequence of events into a potentially smaller list.
|
||||
|
||||
New condenser strategies should override this method to implement their own condensation logic. Call `self.add_metadata` in the implementation to record any relevant per-condensation diagnostic information.
|
||||
|
||||
Args:
|
||||
events: A list of events representing the entire history of the agent.
|
||||
|
||||
Returns:
|
||||
list[Event]: An event sequence representing a condensed history of the agent.
|
||||
"""
|
||||
|
||||
def condensed_history(self, state: State) -> list[Event]:
|
||||
"""Condense the state's history."""
|
||||
with self.metadata_batch(state):
|
||||
return self.condense(state.history)
|
||||
|
||||
@classmethod
|
||||
def register_config(cls, configuration_type: type[CondenserConfig]) -> None:
|
||||
"""Register a new condenser configuration type.
|
||||
|
||||
Instances of registered configuration types can be passed to `from_config` to create instances of the corresponding condenser.
|
||||
|
||||
Args:
|
||||
configuration_type: The type of configuration used to create instances of the condenser.
|
||||
|
||||
Raises:
|
||||
ValueError: If the configuration type is already registered.
|
||||
"""
|
||||
if configuration_type in CONDENSER_REGISTRY:
|
||||
raise ValueError(
|
||||
f'Condenser configuration {configuration_type} is already registered'
|
||||
)
|
||||
CONDENSER_REGISTRY[configuration_type] = cls
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: CondenserConfig) -> Condenser:
|
||||
"""Create a condenser from a configuration object.
|
||||
|
||||
Args:
|
||||
config: Configuration for the condenser.
|
||||
|
||||
Returns:
|
||||
Condenser: A condenser instance.
|
||||
|
||||
Raises:
|
||||
ValueError: If the condenser type is not recognized.
|
||||
"""
|
||||
try:
|
||||
condenser_class = CONDENSER_REGISTRY[type(config)]
|
||||
return condenser_class.from_config(config)
|
||||
except KeyError:
|
||||
raise ValueError(f'Unknown condenser config: {config}')
|
||||
|
||||
|
||||
class RollingCondenser(Condenser, ABC):
|
||||
"""Base class for a specialized condenser strategy that applies condensation to a rolling history.
|
||||
|
||||
The rolling history is computed by appending new events to the most recent condensation. For example, the sequence of calls::
|
||||
|
||||
assert state.history == [event1, event2, event3]
|
||||
condensation = condenser.condensed_history(state)
|
||||
|
||||
# ...new events are added to the state...
|
||||
|
||||
assert state.history == [event1, event2, event3, event4, event5]
|
||||
condenser.condensed_history(state)
|
||||
|
||||
will result in second call to `condensed_history` passing `condensation + [event4, event5]` to the `condense` method.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._condensation: list[Event] = []
|
||||
self._last_history_length: int = 0
|
||||
|
||||
super().__init__()
|
||||
|
||||
@override
|
||||
def condensed_history(self, state: State) -> list[Event]:
|
||||
# The history should grow monotonically -- if it doesn't, something has
|
||||
# truncated the history and we need to reset our tracking.
|
||||
if len(state.history) < self._last_history_length:
|
||||
self._condensation = []
|
||||
self._last_history_length = 0
|
||||
|
||||
new_events = state.history[self._last_history_length :]
|
||||
|
||||
with self.metadata_batch(state):
|
||||
results = self.condense(self._condensation + new_events)
|
||||
|
||||
self._condensation = results
|
||||
self._last_history_length = len(state.history)
|
||||
|
||||
return results
|
||||
@@ -1,31 +0,0 @@
|
||||
from openhands.memory.condenser.impl.amortized_forgetting_condenser import (
|
||||
AmortizedForgettingCondenser,
|
||||
)
|
||||
from openhands.memory.condenser.impl.browser_output_condenser import (
|
||||
BrowserOutputCondenser,
|
||||
)
|
||||
from openhands.memory.condenser.impl.llm_attention_condenser import (
|
||||
ImportantEventSelection,
|
||||
LLMAttentionCondenser,
|
||||
)
|
||||
from openhands.memory.condenser.impl.llm_summarizing_condenser import (
|
||||
LLMSummarizingCondenser,
|
||||
)
|
||||
from openhands.memory.condenser.impl.no_op_condenser import NoOpCondenser
|
||||
from openhands.memory.condenser.impl.observation_masking_condenser import (
|
||||
ObservationMaskingCondenser,
|
||||
)
|
||||
from openhands.memory.condenser.impl.recent_events_condenser import (
|
||||
RecentEventsCondenser,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'AmortizedForgettingCondenser',
|
||||
'LLMAttentionCondenser',
|
||||
'ImportantEventSelection',
|
||||
'LLMSummarizingCondenser',
|
||||
'NoOpCondenser',
|
||||
'ObservationMaskingCondenser',
|
||||
'BrowserOutputCondenser',
|
||||
'RecentEventsCondenser',
|
||||
]
|
||||
@@ -1,55 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from openhands.core.config.condenser_config import AmortizedForgettingCondenserConfig
|
||||
from openhands.events.event import Event
|
||||
from openhands.memory.condenser.condenser import RollingCondenser
|
||||
|
||||
|
||||
class AmortizedForgettingCondenser(RollingCondenser):
|
||||
"""A condenser that maintains a condensed history and forgets old events when it grows too large."""
|
||||
|
||||
def __init__(self, max_size: int = 100, keep_first: int = 0):
|
||||
"""Initialize the condenser.
|
||||
|
||||
Args:
|
||||
max_size: Maximum size of history before forgetting.
|
||||
keep_first: Number of initial events to always keep.
|
||||
|
||||
Raises:
|
||||
ValueError: If keep_first is greater than max_size, keep_first is negative, or max_size is non-positive.
|
||||
"""
|
||||
if keep_first >= max_size // 2:
|
||||
raise ValueError(
|
||||
f'keep_first ({keep_first}) must be less than half of max_size ({max_size})'
|
||||
)
|
||||
if keep_first < 0:
|
||||
raise ValueError(f'keep_first ({keep_first}) cannot be negative')
|
||||
if max_size < 1:
|
||||
raise ValueError(f'max_size ({keep_first}) cannot be non-positive')
|
||||
|
||||
self.max_size = max_size
|
||||
self.keep_first = keep_first
|
||||
|
||||
super().__init__()
|
||||
|
||||
def condense(self, events: list[Event]) -> list[Event]:
|
||||
"""Apply the amortized forgetting strategy to the given list of events."""
|
||||
if len(events) <= self.max_size:
|
||||
return events
|
||||
|
||||
target_size = self.max_size // 2
|
||||
head = events[: self.keep_first]
|
||||
|
||||
events_from_tail = target_size - len(head)
|
||||
tail = events[-events_from_tail:]
|
||||
|
||||
return head + tail
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: AmortizedForgettingCondenserConfig
|
||||
) -> AmortizedForgettingCondenser:
|
||||
return AmortizedForgettingCondenser(**config.model_dump(exclude=['type']))
|
||||
|
||||
|
||||
AmortizedForgettingCondenser.register_config(AmortizedForgettingCondenserConfig)
|
||||
@@ -1,48 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from openhands.core.config.condenser_config import BrowserOutputCondenserConfig
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation import BrowserOutputObservation
|
||||
from openhands.events.observation.agent import AgentCondensationObservation
|
||||
from openhands.memory.condenser.condenser import Condenser
|
||||
|
||||
|
||||
class BrowserOutputCondenser(Condenser):
|
||||
"""A condenser that masks the observations from browser outputs outside of a recent attention window.
|
||||
|
||||
The intent here is to mask just the browser outputs and leave everything else untouched. This is important because currently we provide screenshots and accessibility trees as input to the model for browser observations. These are really large and consume a lot of tokens without any benefits in performance. So we want to mask all such observations from all previous timesteps, and leave only the most recent one in context.
|
||||
"""
|
||||
|
||||
def __init__(self, attention_window: int = 1):
|
||||
self.attention_window = attention_window
|
||||
super().__init__()
|
||||
|
||||
def condense(self, events: list[Event]) -> list[Event]:
|
||||
"""Replace the content of browser observations outside of the attention window with a placeholder."""
|
||||
results: list[Event] = []
|
||||
cnt: int = 0
|
||||
for event in reversed(events):
|
||||
if (
|
||||
isinstance(event, BrowserOutputObservation)
|
||||
and cnt >= self.attention_window
|
||||
):
|
||||
results.append(
|
||||
AgentCondensationObservation(
|
||||
f'Current URL: {event.url}\nContent Omitted'
|
||||
)
|
||||
)
|
||||
else:
|
||||
results.append(event)
|
||||
if isinstance(event, BrowserOutputObservation):
|
||||
cnt += 1
|
||||
|
||||
return list(reversed(results))
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: BrowserOutputCondenserConfig
|
||||
) -> BrowserOutputCondenser:
|
||||
return BrowserOutputCondenser(**config.model_dump(exclude=['type']))
|
||||
|
||||
|
||||
BrowserOutputCondenser.register_config(BrowserOutputCondenserConfig)
|
||||
@@ -1,116 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from litellm import supports_response_schema
|
||||
from pydantic import BaseModel
|
||||
|
||||
from openhands.core.config.condenser_config import LLMAttentionCondenserConfig
|
||||
from openhands.events.event import Event
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.memory.condenser.condenser import RollingCondenser
|
||||
|
||||
|
||||
class ImportantEventSelection(BaseModel):
|
||||
"""Utility class for the `LLMAttentionCondenser` that forces the LLM to return a list of integers."""
|
||||
|
||||
ids: list[int]
|
||||
|
||||
|
||||
class LLMAttentionCondenser(RollingCondenser):
|
||||
"""Rolling condenser strategy that uses an LLM to select the most important events when condensing the history."""
|
||||
|
||||
def __init__(self, llm: LLM, max_size: int = 100, keep_first: int = 1):
|
||||
if keep_first >= max_size // 2:
|
||||
raise ValueError(
|
||||
f'keep_first ({keep_first}) must be less than half of max_size ({max_size})'
|
||||
)
|
||||
if keep_first < 0:
|
||||
raise ValueError(f'keep_first ({keep_first}) cannot be negative')
|
||||
if max_size < 1:
|
||||
raise ValueError(f'max_size ({keep_first}) cannot be non-positive')
|
||||
|
||||
self.max_size = max_size
|
||||
self.keep_first = keep_first
|
||||
self.llm = llm
|
||||
|
||||
# This condenser relies on the `response_schema` feature, which is not supported by all LLMs
|
||||
if not supports_response_schema(
|
||||
model=self.llm.config.model,
|
||||
custom_llm_provider=self.llm.config.custom_llm_provider,
|
||||
):
|
||||
raise ValueError(
|
||||
"The LLM model must support the 'response_schema' parameter to use the LLMAttentionCondenser."
|
||||
)
|
||||
|
||||
super().__init__()
|
||||
|
||||
def condense(self, events: list[Event]) -> list[Event]:
|
||||
"""If the history is too long, use an LLM to select the most important events."""
|
||||
if len(events) <= self.max_size:
|
||||
return events
|
||||
|
||||
target_size = self.max_size // 2
|
||||
head = events[: self.keep_first]
|
||||
|
||||
events_from_tail = target_size - len(head)
|
||||
|
||||
message: str = """You will be given a list of actions, observations, and thoughts from a coding agent.
|
||||
Each item in the list has an identifier. Please sort the identifiers in order of how important the
|
||||
contents of the item are for the next step of the coding agent's task, from most important to least
|
||||
important."""
|
||||
|
||||
response = self.llm.completion(
|
||||
messages=[
|
||||
{'content': message, 'role': 'user'},
|
||||
*[
|
||||
{
|
||||
'content': f'<ID>{e.id}</ID>\n<CONTENT>{e.message}</CONTENT>',
|
||||
'role': 'user',
|
||||
}
|
||||
for e in events
|
||||
],
|
||||
],
|
||||
response_format={
|
||||
'type': 'json_schema',
|
||||
'json_schema': {
|
||||
'name': 'ImportantEventSelection',
|
||||
'schema': ImportantEventSelection.model_json_schema(),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
response_ids = ImportantEventSelection.model_validate_json(
|
||||
response.choices[0].message.content
|
||||
).ids
|
||||
|
||||
self.add_metadata('all_event_ids', [event.id for event in events])
|
||||
self.add_metadata('response_ids', response_ids)
|
||||
self.add_metadata('metrics', self.llm.metrics.get())
|
||||
|
||||
# Filter out any IDs from the head and trim the results down
|
||||
head_ids = [event.id for event in head]
|
||||
response_ids = [
|
||||
response_id for response_id in response_ids if response_id not in head_ids
|
||||
][:events_from_tail]
|
||||
|
||||
# If the response IDs aren't _long_ enough, iterate backwards through the events and add any unfound IDs to the list.
|
||||
for event in reversed(events):
|
||||
if len(response_ids) >= events_from_tail:
|
||||
break
|
||||
if event.id not in response_ids:
|
||||
response_ids.append(event.id)
|
||||
|
||||
# Grab the events associated with the response IDs
|
||||
tail = [event for event in events if event.id in response_ids]
|
||||
|
||||
return head + tail
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: LLMAttentionCondenserConfig) -> LLMAttentionCondenser:
|
||||
return LLMAttentionCondenser(
|
||||
llm=LLM(config=config.llm_config),
|
||||
max_size=config.max_size,
|
||||
keep_first=config.keep_first,
|
||||
)
|
||||
|
||||
|
||||
LLMAttentionCondenser.register_config(LLMAttentionCondenserConfig)
|
||||
@@ -1,119 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from openhands.core.config.condenser_config import LLMSummarizingCondenserConfig
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation.agent import AgentCondensationObservation
|
||||
from openhands.llm import LLM
|
||||
from openhands.memory.condenser.condenser import RollingCondenser
|
||||
|
||||
|
||||
class LLMSummarizingCondenser(RollingCondenser):
|
||||
"""A condenser that summarizes forgotten events.
|
||||
|
||||
Maintains a condensed history and forgets old events when it grows too large,
|
||||
keeping a special summarization event after the prefix that summarizes all previous summarizations
|
||||
and newly forgotten events.
|
||||
"""
|
||||
|
||||
def __init__(self, llm: LLM, max_size: int = 100, keep_first: int = 1):
|
||||
if keep_first >= max_size // 2:
|
||||
raise ValueError(
|
||||
f'keep_first ({keep_first}) must be less than half of max_size ({max_size})'
|
||||
)
|
||||
if keep_first < 0:
|
||||
raise ValueError(f'keep_first ({keep_first}) cannot be negative')
|
||||
if max_size < 1:
|
||||
raise ValueError(f'max_size ({max_size}) cannot be non-positive')
|
||||
|
||||
self.max_size = max_size
|
||||
self.keep_first = keep_first
|
||||
self.llm = llm
|
||||
|
||||
super().__init__()
|
||||
|
||||
def condense(self, events: list[Event]) -> list[Event]:
|
||||
"""Apply the amortized forgetting strategy with LLM summarization to the given list of events."""
|
||||
if len(events) <= self.max_size:
|
||||
return events
|
||||
|
||||
head = events[: self.keep_first]
|
||||
|
||||
target_size = self.max_size // 2
|
||||
events_from_tail = target_size - len(head)
|
||||
tail = events[-events_from_tail:]
|
||||
|
||||
summary_event = (
|
||||
events[self.keep_first]
|
||||
if isinstance(events[self.keep_first], AgentCondensationObservation)
|
||||
else AgentCondensationObservation('No events summarized')
|
||||
)
|
||||
|
||||
# Identify events to be forgotten (those not in head or tail)
|
||||
forgotten_events = []
|
||||
for event in events[self.keep_first : -events_from_tail]:
|
||||
if not isinstance(event, AgentCondensationObservation):
|
||||
forgotten_events.append(event)
|
||||
|
||||
# Construct prompt for summarization
|
||||
prompt = """You are maintaining state history for an LLM-based code agent. Track:
|
||||
|
||||
USER_CONTEXT: (Preserve essential user requirements, problem descriptions, and clarifications in concise form)
|
||||
|
||||
STATE: {File paths, function signatures, data structures}
|
||||
TESTS: {Failing cases, error messages, outputs}
|
||||
CHANGES: {Code edits, variable updates}
|
||||
DEPS: {Dependencies, imports, external calls}
|
||||
INTENT: {Why changes were made, acceptance criteria}
|
||||
|
||||
PRIORITIZE:
|
||||
1. Capture key user requirements and constraints
|
||||
2. Maintain critical problem context
|
||||
3. Keep all sections concise
|
||||
|
||||
SKIP: {Git clones, build logs, file listings}
|
||||
|
||||
Example history format:
|
||||
USER_CONTEXT: Fix FITS card float representation - "0.009125" becomes "0.009124999999999999" causing comment truncation. Use Python's str() when possible while maintaining FITS compliance.
|
||||
|
||||
STATE: mod_float() in card.py updated
|
||||
TESTS: test_format() passed
|
||||
CHANGES: str(val) replaces f"{val:.16G}"
|
||||
DEPS: None modified
|
||||
INTENT: Fix precision while maintaining FITS compliance"""
|
||||
|
||||
prompt + '\n\n'
|
||||
|
||||
prompt += ('\n' + summary_event.message + '\n') if summary_event.message else ''
|
||||
|
||||
prompt + '\n\n'
|
||||
|
||||
for forgotten_event in forgotten_events:
|
||||
prompt += str(forgotten_event) + '\n\n'
|
||||
|
||||
response = self.llm.completion(
|
||||
messages=[
|
||||
{
|
||||
'content': prompt,
|
||||
'role': 'user',
|
||||
},
|
||||
],
|
||||
)
|
||||
summary = response.choices[0].message.content
|
||||
|
||||
self.add_metadata('response', response.model_dump())
|
||||
self.add_metadata('metrics', self.llm.metrics.get())
|
||||
|
||||
return head + [AgentCondensationObservation(summary)] + tail
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: LLMSummarizingCondenserConfig
|
||||
) -> LLMSummarizingCondenser:
|
||||
return LLMSummarizingCondenser(
|
||||
llm=LLM(config=config.llm_config),
|
||||
max_size=config.max_size,
|
||||
keep_first=config.keep_first,
|
||||
)
|
||||
|
||||
|
||||
LLMSummarizingCondenser.register_config(LLMSummarizingCondenserConfig)
|
||||
@@ -1,20 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from openhands.core.config.condenser_config import NoOpCondenserConfig
|
||||
from openhands.events.event import Event
|
||||
from openhands.memory.condenser.condenser import Condenser
|
||||
|
||||
|
||||
class NoOpCondenser(Condenser):
|
||||
"""A condenser that does nothing to the event sequence."""
|
||||
|
||||
def condense(self, events: list[Event]) -> list[Event]:
|
||||
"""Returns the list of events unchanged."""
|
||||
return events
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: NoOpCondenserConfig) -> NoOpCondenser:
|
||||
return NoOpCondenser()
|
||||
|
||||
|
||||
NoOpCondenser.register_config(NoOpCondenserConfig)
|
||||
@@ -1,39 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from openhands.core.config.condenser_config import ObservationMaskingCondenserConfig
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation import Observation
|
||||
from openhands.events.observation.agent import AgentCondensationObservation
|
||||
from openhands.memory.condenser.condenser import Condenser
|
||||
|
||||
|
||||
class ObservationMaskingCondenser(Condenser):
|
||||
"""A condenser that masks the values of observations outside of a recent attention window."""
|
||||
|
||||
def __init__(self, attention_window: int = 5):
|
||||
self.attention_window = attention_window
|
||||
|
||||
super().__init__()
|
||||
|
||||
def condense(self, events: list[Event]) -> list[Event]:
|
||||
"""Replace the content of observations outside of the attention window with a placeholder."""
|
||||
results: list[Event] = []
|
||||
for i, event in enumerate(events):
|
||||
if (
|
||||
isinstance(event, Observation)
|
||||
and i < len(events) - self.attention_window
|
||||
):
|
||||
results.append(AgentCondensationObservation('<MASKED>'))
|
||||
else:
|
||||
results.append(event)
|
||||
|
||||
return results
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: ObservationMaskingCondenserConfig
|
||||
) -> ObservationMaskingCondenser:
|
||||
return ObservationMaskingCondenser(**config.model_dump(exclude=['type']))
|
||||
|
||||
|
||||
ObservationMaskingCondenser.register_config(ObservationMaskingCondenserConfig)
|
||||
@@ -1,29 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from openhands.core.config.condenser_config import RecentEventsCondenserConfig
|
||||
from openhands.events.event import Event
|
||||
from openhands.memory.condenser.condenser import Condenser
|
||||
|
||||
|
||||
class RecentEventsCondenser(Condenser):
|
||||
"""A condenser that only keeps a certain number of the most recent events."""
|
||||
|
||||
def __init__(self, keep_first: int = 1, max_events: int = 10):
|
||||
self.keep_first = keep_first
|
||||
self.max_events = max_events
|
||||
|
||||
super().__init__()
|
||||
|
||||
def condense(self, events: list[Event]) -> list[Event]:
|
||||
"""Keep only the most recent events (up to `max_events`)."""
|
||||
head = events[: self.keep_first]
|
||||
tail_length = max(0, self.max_events - len(head))
|
||||
tail = events[-tail_length:]
|
||||
return head + tail
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: RecentEventsCondenserConfig) -> RecentEventsCondenser:
|
||||
return RecentEventsCondenser(**config.model_dump(exclude=['type']))
|
||||
|
||||
|
||||
RecentEventsCondenser.register_config(RecentEventsCondenserConfig)
|
||||
@@ -1,406 +0,0 @@
|
||||
from litellm import ModelResponse
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.message import ImageContent, Message, TextContent
|
||||
from openhands.core.schema import ActionType
|
||||
from openhands.events.action import (
|
||||
Action,
|
||||
AgentDelegateAction,
|
||||
AgentFinishAction,
|
||||
AgentThinkAction,
|
||||
BrowseInteractiveAction,
|
||||
BrowseURLAction,
|
||||
CmdRunAction,
|
||||
FileEditAction,
|
||||
FileReadAction,
|
||||
IPythonRunCellAction,
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation import (
|
||||
AgentCondensationObservation,
|
||||
AgentDelegateObservation,
|
||||
AgentThinkObservation,
|
||||
BrowserOutputObservation,
|
||||
CmdOutputObservation,
|
||||
FileEditObservation,
|
||||
FileReadObservation,
|
||||
IPythonRunCellObservation,
|
||||
UserRejectObservation,
|
||||
)
|
||||
from openhands.events.observation.error import ErrorObservation
|
||||
from openhands.events.observation.observation import Observation
|
||||
from openhands.events.serialization.event import truncate_content
|
||||
from openhands.utils.prompt import PromptManager
|
||||
|
||||
|
||||
class ConversationMemory:
|
||||
"""Processes event history into a coherent conversation for the agent."""
|
||||
|
||||
def __init__(self, prompt_manager: PromptManager):
|
||||
self.prompt_manager = prompt_manager
|
||||
|
||||
def process_events(
|
||||
self,
|
||||
condensed_history: list[Event],
|
||||
initial_messages: list[Message],
|
||||
max_message_chars: int | None = None,
|
||||
vision_is_active: bool = False,
|
||||
enable_som_visual_browsing: bool = False,
|
||||
) -> list[Message]:
|
||||
"""Process state history into a list of messages for the LLM.
|
||||
|
||||
Ensures that tool call actions are processed correctly in function calling mode.
|
||||
|
||||
Args:
|
||||
state: The state containing the history of events to convert
|
||||
condensed_history: The condensed list of events to process
|
||||
initial_messages: The initial messages to include in the result
|
||||
max_message_chars: The maximum number of characters in the content of an event included
|
||||
in the prompt to the LLM. Larger observations are truncated.
|
||||
vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included.
|
||||
enable_som_visual_browsing: Whether to enable visual browsing for the SOM model.
|
||||
"""
|
||||
events = condensed_history
|
||||
|
||||
# Process special events first (system prompts, etc.)
|
||||
messages = initial_messages
|
||||
|
||||
# Process regular events
|
||||
pending_tool_call_action_messages: dict[str, Message] = {}
|
||||
tool_call_id_to_message: dict[str, Message] = {}
|
||||
|
||||
for event in events:
|
||||
# create a regular message from an event
|
||||
if isinstance(event, Action):
|
||||
messages_to_add = self._process_action(
|
||||
action=event,
|
||||
pending_tool_call_action_messages=pending_tool_call_action_messages,
|
||||
vision_is_active=vision_is_active,
|
||||
)
|
||||
elif isinstance(event, Observation):
|
||||
messages_to_add = self._process_observation(
|
||||
obs=event,
|
||||
tool_call_id_to_message=tool_call_id_to_message,
|
||||
max_message_chars=max_message_chars,
|
||||
vision_is_active=vision_is_active,
|
||||
enable_som_visual_browsing=enable_som_visual_browsing,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'Unknown event type: {type(event)}')
|
||||
|
||||
# Check pending tool call action messages and see if they are complete
|
||||
_response_ids_to_remove = []
|
||||
for (
|
||||
response_id,
|
||||
pending_message,
|
||||
) in pending_tool_call_action_messages.items():
|
||||
assert pending_message.tool_calls is not None, (
|
||||
'Tool calls should NOT be None when function calling is enabled & the message is considered pending tool call. '
|
||||
f'Pending message: {pending_message}'
|
||||
)
|
||||
if all(
|
||||
tool_call.id in tool_call_id_to_message
|
||||
for tool_call in pending_message.tool_calls
|
||||
):
|
||||
# If complete:
|
||||
# -- 1. Add the message that **initiated** the tool calls
|
||||
messages_to_add.append(pending_message)
|
||||
# -- 2. Add the tool calls **results***
|
||||
for tool_call in pending_message.tool_calls:
|
||||
messages_to_add.append(tool_call_id_to_message[tool_call.id])
|
||||
tool_call_id_to_message.pop(tool_call.id)
|
||||
_response_ids_to_remove.append(response_id)
|
||||
# Cleanup the processed pending tool messages
|
||||
for response_id in _response_ids_to_remove:
|
||||
pending_tool_call_action_messages.pop(response_id)
|
||||
|
||||
messages += messages_to_add
|
||||
|
||||
return messages
|
||||
|
||||
def process_initial_messages(self, with_caching: bool = False) -> list[Message]:
|
||||
"""Create the initial messages for the conversation."""
|
||||
return [
|
||||
Message(
|
||||
role='system',
|
||||
content=[
|
||||
TextContent(
|
||||
text=self.prompt_manager.get_system_message(),
|
||||
cache_prompt=with_caching,
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
def _process_action(
|
||||
self,
|
||||
action: Action,
|
||||
pending_tool_call_action_messages: dict[str, Message],
|
||||
vision_is_active: bool = False,
|
||||
) -> list[Message]:
|
||||
"""Converts an action into a message format that can be sent to the LLM.
|
||||
|
||||
This method handles different types of actions and formats them appropriately:
|
||||
1. For tool-based actions (AgentDelegate, CmdRun, IPythonRunCell, FileEdit) and agent-sourced AgentFinish:
|
||||
- In function calling mode: Stores the LLM's response in pending_tool_call_action_messages
|
||||
- In non-function calling mode: Creates a message with the action string
|
||||
2. For MessageActions: Creates a message with the text content and optional image content
|
||||
|
||||
Args:
|
||||
action: The action to convert. Can be one of:
|
||||
- CmdRunAction: For executing bash commands
|
||||
- IPythonRunCellAction: For running IPython code
|
||||
- FileEditAction: For editing files
|
||||
- FileReadAction: For reading files using openhands-aci commands
|
||||
- BrowseInteractiveAction: For browsing the web
|
||||
- AgentFinishAction: For ending the interaction
|
||||
- MessageAction: For sending messages
|
||||
|
||||
pending_tool_call_action_messages: Dictionary mapping response IDs to their corresponding messages.
|
||||
Used in function calling mode to track tool calls that are waiting for their results.
|
||||
|
||||
vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included
|
||||
|
||||
Returns:
|
||||
list[Message]: A list containing the formatted message(s) for the action.
|
||||
May be empty if the action is handled as a tool call in function calling mode.
|
||||
|
||||
Note:
|
||||
In function calling mode, tool-based actions are stored in pending_tool_call_action_messages
|
||||
rather than being returned immediately. They will be processed later when all corresponding
|
||||
tool call results are available.
|
||||
"""
|
||||
# create a regular message from an event
|
||||
if isinstance(
|
||||
action,
|
||||
(
|
||||
AgentDelegateAction,
|
||||
AgentThinkAction,
|
||||
IPythonRunCellAction,
|
||||
FileEditAction,
|
||||
FileReadAction,
|
||||
BrowseInteractiveAction,
|
||||
BrowseURLAction,
|
||||
),
|
||||
) or (isinstance(action, CmdRunAction) and action.source == 'agent'):
|
||||
tool_metadata = action.tool_call_metadata
|
||||
assert tool_metadata is not None, (
|
||||
'Tool call metadata should NOT be None when function calling is enabled. Action: '
|
||||
+ str(action)
|
||||
)
|
||||
|
||||
llm_response: ModelResponse = tool_metadata.model_response
|
||||
assistant_msg = getattr(llm_response.choices[0], 'message')
|
||||
|
||||
# Add the LLM message (assistant) that initiated the tool calls
|
||||
# (overwrites any previous message with the same response_id)
|
||||
logger.debug(
|
||||
f'Tool calls type: {type(assistant_msg.tool_calls)}, value: {assistant_msg.tool_calls}'
|
||||
)
|
||||
pending_tool_call_action_messages[llm_response.id] = Message(
|
||||
role=getattr(assistant_msg, 'role', 'assistant'),
|
||||
# tool call content SHOULD BE a string
|
||||
content=[TextContent(text=assistant_msg.content or '')]
|
||||
if assistant_msg.content is not None
|
||||
else [],
|
||||
tool_calls=assistant_msg.tool_calls,
|
||||
)
|
||||
return []
|
||||
elif isinstance(action, AgentFinishAction):
|
||||
role = 'user' if action.source == 'user' else 'assistant'
|
||||
|
||||
# when agent finishes, it has tool_metadata
|
||||
# which has already been executed, and it doesn't have a response
|
||||
# when the user finishes (/exit), we don't have tool_metadata
|
||||
tool_metadata = action.tool_call_metadata
|
||||
if tool_metadata is not None:
|
||||
# take the response message from the tool call
|
||||
assistant_msg = getattr(
|
||||
tool_metadata.model_response.choices[0], 'message'
|
||||
)
|
||||
content = assistant_msg.content or ''
|
||||
|
||||
# save content if any, to thought
|
||||
if action.thought:
|
||||
if action.thought != content:
|
||||
action.thought += '\n' + content
|
||||
else:
|
||||
action.thought = content
|
||||
|
||||
# remove the tool call metadata
|
||||
action.tool_call_metadata = None
|
||||
if role not in ('user', 'system', 'assistant', 'tool'):
|
||||
raise ValueError(f'Invalid role: {role}')
|
||||
return [
|
||||
Message(
|
||||
role=role, # type: ignore[arg-type]
|
||||
content=[TextContent(text=action.thought)],
|
||||
)
|
||||
]
|
||||
elif isinstance(action, MessageAction):
|
||||
role = 'user' if action.source == 'user' else 'assistant'
|
||||
content = [TextContent(text=action.content or '')]
|
||||
if vision_is_active and action.image_urls:
|
||||
content.append(ImageContent(image_urls=action.image_urls))
|
||||
if role not in ('user', 'system', 'assistant', 'tool'):
|
||||
raise ValueError(f'Invalid role: {role}')
|
||||
return [
|
||||
Message(
|
||||
role=role, # type: ignore[arg-type]
|
||||
content=content,
|
||||
)
|
||||
]
|
||||
elif isinstance(action, CmdRunAction) and action.source == 'user':
|
||||
content = [
|
||||
TextContent(text=f'User executed the command:\n{action.command}')
|
||||
]
|
||||
return [
|
||||
Message(
|
||||
role='user', # Always user for CmdRunAction
|
||||
content=content,
|
||||
)
|
||||
]
|
||||
return []
|
||||
|
||||
def _process_observation(
|
||||
self,
|
||||
obs: Observation,
|
||||
tool_call_id_to_message: dict[str, Message],
|
||||
max_message_chars: int | None = None,
|
||||
vision_is_active: bool = False,
|
||||
enable_som_visual_browsing: bool = False,
|
||||
) -> list[Message]:
|
||||
"""Converts an observation into a message format that can be sent to the LLM.
|
||||
|
||||
This method handles different types of observations and formats them appropriately:
|
||||
- CmdOutputObservation: Formats command execution results with exit codes
|
||||
- IPythonRunCellObservation: Formats IPython cell execution results, replacing base64 images
|
||||
- FileEditObservation: Formats file editing results
|
||||
- FileReadObservation: Formats file reading results from openhands-aci
|
||||
- AgentDelegateObservation: Formats results from delegated agent tasks
|
||||
- ErrorObservation: Formats error messages from failed actions
|
||||
- UserRejectObservation: Formats user rejection messages
|
||||
|
||||
In function calling mode, observations with tool_call_metadata are stored in
|
||||
tool_call_id_to_message for later processing instead of being returned immediately.
|
||||
|
||||
Args:
|
||||
obs: The observation to convert
|
||||
tool_call_id_to_message: Dictionary mapping tool call IDs to their corresponding messages (used in function calling mode)
|
||||
max_message_chars: The maximum number of characters in the content of an observation included in the prompt to the LLM
|
||||
vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included
|
||||
enable_som_visual_browsing: Whether to enable visual browsing for the SOM model
|
||||
|
||||
Returns:
|
||||
list[Message]: A list containing the formatted message(s) for the observation.
|
||||
May be empty if the observation is handled as a tool response in function calling mode.
|
||||
|
||||
Raises:
|
||||
ValueError: If the observation type is unknown
|
||||
"""
|
||||
message: Message
|
||||
|
||||
if isinstance(obs, CmdOutputObservation):
|
||||
# if it doesn't have tool call metadata, it was triggered by a user action
|
||||
if obs.tool_call_metadata is None:
|
||||
text = truncate_content(
|
||||
f'\nObserved result of command executed by user:\n{obs.to_agent_observation()}',
|
||||
max_message_chars,
|
||||
)
|
||||
else:
|
||||
text = truncate_content(obs.to_agent_observation(), max_message_chars)
|
||||
message = Message(role='user', content=[TextContent(text=text)])
|
||||
elif isinstance(obs, IPythonRunCellObservation):
|
||||
text = obs.content
|
||||
# replace base64 images with a placeholder
|
||||
splitted = text.split('\n')
|
||||
for i, line in enumerate(splitted):
|
||||
if ' already displayed to user'
|
||||
)
|
||||
text = '\n'.join(splitted)
|
||||
text = truncate_content(text, max_message_chars)
|
||||
message = Message(role='user', content=[TextContent(text=text)])
|
||||
elif isinstance(obs, FileEditObservation):
|
||||
text = truncate_content(str(obs), max_message_chars)
|
||||
message = Message(role='user', content=[TextContent(text=text)])
|
||||
elif isinstance(obs, FileReadObservation):
|
||||
message = Message(
|
||||
role='user', content=[TextContent(text=obs.content)]
|
||||
) # Content is already truncated by openhands-aci
|
||||
elif isinstance(obs, BrowserOutputObservation):
|
||||
text = obs.get_agent_obs_text()
|
||||
if (
|
||||
obs.trigger_by_action == ActionType.BROWSE_INTERACTIVE
|
||||
and obs.set_of_marks is not None
|
||||
and len(obs.set_of_marks) > 0
|
||||
and enable_som_visual_browsing
|
||||
and vision_is_active
|
||||
):
|
||||
text += 'Image: Current webpage screenshot (Note that only visible portion of webpage is present in the screenshot. You may need to scroll to view the remaining portion of the web-page.)\n'
|
||||
message = Message(
|
||||
role='user',
|
||||
content=[
|
||||
TextContent(text=text),
|
||||
ImageContent(image_urls=[obs.set_of_marks]),
|
||||
],
|
||||
)
|
||||
else:
|
||||
message = Message(
|
||||
role='user',
|
||||
content=[TextContent(text=text)],
|
||||
)
|
||||
elif isinstance(obs, AgentDelegateObservation):
|
||||
text = truncate_content(
|
||||
obs.outputs['content'] if 'content' in obs.outputs else '',
|
||||
max_message_chars,
|
||||
)
|
||||
message = Message(role='user', content=[TextContent(text=text)])
|
||||
elif isinstance(obs, AgentThinkObservation):
|
||||
text = truncate_content(obs.content, max_message_chars)
|
||||
message = Message(role='user', content=[TextContent(text=text)])
|
||||
elif isinstance(obs, ErrorObservation):
|
||||
text = truncate_content(obs.content, max_message_chars)
|
||||
text += '\n[Error occurred in processing last action]'
|
||||
message = Message(role='user', content=[TextContent(text=text)])
|
||||
elif isinstance(obs, UserRejectObservation):
|
||||
text = 'OBSERVATION:\n' + truncate_content(obs.content, max_message_chars)
|
||||
text += '\n[Last action has been rejected by the user]'
|
||||
message = Message(role='user', content=[TextContent(text=text)])
|
||||
elif isinstance(obs, AgentCondensationObservation):
|
||||
text = truncate_content(obs.content, max_message_chars)
|
||||
message = Message(role='user', content=[TextContent(text=text)])
|
||||
else:
|
||||
# If an observation message is not returned, it will cause an error
|
||||
# when the LLM tries to return the next message
|
||||
raise ValueError(f'Unknown observation type: {type(obs)}')
|
||||
|
||||
# Update the message as tool response properly
|
||||
if (tool_call_metadata := getattr(obs, 'tool_call_metadata', None)) is not None:
|
||||
tool_call_id_to_message[tool_call_metadata.tool_call_id] = Message(
|
||||
role='tool',
|
||||
content=message.content,
|
||||
tool_call_id=tool_call_metadata.tool_call_id,
|
||||
name=tool_call_metadata.function_name,
|
||||
)
|
||||
# No need to return the observation message
|
||||
# because it will be added by get_action_message when all the corresponding
|
||||
# tool calls in the SAME request are processed
|
||||
return []
|
||||
|
||||
return [message]
|
||||
|
||||
def apply_prompt_caching(self, messages: list[Message]) -> None:
|
||||
"""Applies caching breakpoints to the messages.
|
||||
|
||||
For new Anthropic API, we only need to mark the last user or tool message as cacheable.
|
||||
"""
|
||||
# NOTE: this is only needed for anthropic
|
||||
for message in reversed(messages):
|
||||
if message.role in ('user', 'tool'):
|
||||
message.content[
|
||||
-1
|
||||
].cache_prompt = True # Last item inside the message content
|
||||
break
|
||||
@@ -1,188 +0,0 @@
|
||||
import json
|
||||
|
||||
from openhands.core.config import AgentConfig, LLMConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.serialization.event import event_to_memory
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.utils.embeddings import (
|
||||
LLAMA_INDEX_AVAILABLE,
|
||||
EmbeddingsLoader,
|
||||
check_llama_index,
|
||||
)
|
||||
|
||||
# Conditional imports based on llama_index availability
|
||||
if LLAMA_INDEX_AVAILABLE:
|
||||
import chromadb
|
||||
from llama_index.core import Document
|
||||
from llama_index.core.indices.vector_store.base import VectorStoreIndex
|
||||
from llama_index.core.indices.vector_store.retrievers.retriever import (
|
||||
VectorIndexRetriever,
|
||||
)
|
||||
from llama_index.core.schema import TextNode
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
|
||||
|
||||
class LongTermMemory:
|
||||
"""Handles storing information for the agent to access later, using chromadb."""
|
||||
|
||||
event_stream: EventStream
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_config: LLMConfig,
|
||||
agent_config: AgentConfig,
|
||||
event_stream: EventStream,
|
||||
):
|
||||
"""Initialize the chromadb and set up ChromaVectorStore for later use."""
|
||||
|
||||
check_llama_index()
|
||||
|
||||
# initialize the chromadb client
|
||||
db = chromadb.PersistentClient(
|
||||
path=f'./cache/sessions/{event_stream.sid}/memory',
|
||||
# FIXME anonymized_telemetry=False,
|
||||
)
|
||||
self.collection = db.get_or_create_collection(name='memories')
|
||||
vector_store = ChromaVectorStore(chroma_collection=self.collection)
|
||||
|
||||
# embedding model
|
||||
embedding_strategy = llm_config.embedding_model
|
||||
self.embed_model = EmbeddingsLoader.get_embedding_model(
|
||||
embedding_strategy, llm_config
|
||||
)
|
||||
logger.debug(f'Using embedding model: {self.embed_model}')
|
||||
|
||||
# instantiate the index
|
||||
self.index = VectorStoreIndex.from_vector_store(vector_store, self.embed_model)
|
||||
self.thought_idx = 0
|
||||
|
||||
# initialize the event stream
|
||||
self.event_stream = event_stream
|
||||
|
||||
# max of threads to run the pipeline
|
||||
self.memory_max_threads = agent_config.memory_max_threads
|
||||
|
||||
def add_event(self, event: Event):
|
||||
"""Adds a new event to the long term memory with a unique id.
|
||||
|
||||
Parameters:
|
||||
- event: The new event to be added to memory
|
||||
"""
|
||||
try:
|
||||
# convert the event to a memory-friendly format, and don't truncate
|
||||
event_data = event_to_memory(event, -1)
|
||||
except (json.JSONDecodeError, KeyError, ValueError) as e:
|
||||
logger.warning(f'Failed to process event: {e}')
|
||||
return
|
||||
|
||||
# determine the event type and ID
|
||||
event_type = ''
|
||||
event_id = ''
|
||||
if 'action' in event_data:
|
||||
event_type = 'action'
|
||||
event_id = event_data['action']
|
||||
elif 'observation' in event_data:
|
||||
event_type = 'observation'
|
||||
event_id = event_data['observation']
|
||||
|
||||
# create a Document instance for the event
|
||||
doc = Document(
|
||||
text=json.dumps(event_data),
|
||||
doc_id=str(self.thought_idx),
|
||||
extra_info={
|
||||
'type': event_type,
|
||||
'id': event_id,
|
||||
'idx': self.thought_idx,
|
||||
},
|
||||
)
|
||||
self.thought_idx += 1
|
||||
logger.debug('Adding %s event to memory: %d', event_type, self.thought_idx)
|
||||
self._add_document(document=doc)
|
||||
|
||||
def _add_document(self, document: 'Document'):
|
||||
"""Inserts a single document into the index."""
|
||||
self.index.insert_nodes([self._create_node(document)])
|
||||
|
||||
def _create_node(self, document: 'Document') -> 'TextNode':
|
||||
"""Create a TextNode from a Document instance."""
|
||||
return TextNode(
|
||||
text=document.text,
|
||||
doc_id=document.doc_id,
|
||||
extra_info=document.extra_info,
|
||||
)
|
||||
|
||||
def search(self, query: str, k: int = 10) -> list[str]:
|
||||
"""Searches through the current memory using VectorIndexRetriever.
|
||||
|
||||
Parameters:
|
||||
- query (str): A query to match search results to
|
||||
- k (int): Number of top results to return
|
||||
|
||||
Returns:
|
||||
- list[str]: List of top k results found in current memory
|
||||
"""
|
||||
retriever = VectorIndexRetriever(
|
||||
index=self.index,
|
||||
similarity_top_k=k,
|
||||
)
|
||||
results = retriever.retrieve(query)
|
||||
|
||||
for result in results:
|
||||
logger.debug(
|
||||
f'Doc ID: {result.doc_id}:\n Text: {result.get_text()}\n Score: {result.score}'
|
||||
)
|
||||
|
||||
return [r.get_text() for r in results]
|
||||
|
||||
def _events_to_docs(self) -> list['Document']:
|
||||
"""Convert all events from the EventStream to documents for batch insert into the index."""
|
||||
try:
|
||||
events = self.event_stream.get_events()
|
||||
except Exception as e:
|
||||
logger.debug(f'No events found for session {self.event_stream.sid}: {e}')
|
||||
return []
|
||||
|
||||
documents: list[Document] = []
|
||||
|
||||
for event in events:
|
||||
try:
|
||||
# convert the event to a memory-friendly format, and don't truncate
|
||||
event_data = event_to_memory(event, -1)
|
||||
|
||||
# determine the event type and ID
|
||||
event_type = ''
|
||||
event_id = ''
|
||||
if 'action' in event_data:
|
||||
event_type = 'action'
|
||||
event_id = event_data['action']
|
||||
elif 'observation' in event_data:
|
||||
event_type = 'observation'
|
||||
event_id = event_data['observation']
|
||||
|
||||
# create a Document instance for the event
|
||||
doc = Document(
|
||||
text=json.dumps(event_data),
|
||||
doc_id=str(self.thought_idx),
|
||||
extra_info={
|
||||
'type': event_type,
|
||||
'id': event_id,
|
||||
'idx': self.thought_idx,
|
||||
},
|
||||
)
|
||||
documents.append(doc)
|
||||
self.thought_idx += 1
|
||||
except (json.JSONDecodeError, KeyError, ValueError) as e:
|
||||
logger.warning(f'Failed to process event: {e}')
|
||||
continue
|
||||
|
||||
if documents:
|
||||
logger.debug(f'Batch inserting {len(documents)} documents into the index.')
|
||||
else:
|
||||
logger.debug('No valid documents found to insert into the index.')
|
||||
|
||||
return documents
|
||||
|
||||
def create_nodes(self, documents: list['Document']) -> list['TextNode']:
|
||||
"""Create nodes from a list of documents."""
|
||||
return [self._create_node(doc) for doc in documents]
|
||||
@@ -1,187 +0,0 @@
|
||||
import importlib.util
|
||||
import os
|
||||
|
||||
from joblib import Parallel, delayed
|
||||
|
||||
from openhands.core.config import LLMConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
try:
|
||||
# check if those we need later are available using importlib
|
||||
if importlib.util.find_spec('chromadb') is None:
|
||||
raise ImportError(
|
||||
'chromadb is not available. Please install it using poetry install --with llama-index'
|
||||
)
|
||||
|
||||
if (
|
||||
importlib.util.find_spec(
|
||||
'llama_index.core.indices.vector_store.retrievers.retriever'
|
||||
)
|
||||
is None
|
||||
or importlib.util.find_spec('llama_index.core.indices.vector_store.base')
|
||||
is None
|
||||
):
|
||||
raise ImportError(
|
||||
'llama_index is not available. Please install it using poetry install --with llama-index'
|
||||
)
|
||||
|
||||
from llama_index.core import Document, VectorStoreIndex
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||
from llama_index.core.ingestion import IngestionPipeline
|
||||
from llama_index.core.schema import TextNode
|
||||
|
||||
LLAMA_INDEX_AVAILABLE = True
|
||||
|
||||
except ImportError:
|
||||
LLAMA_INDEX_AVAILABLE = False
|
||||
|
||||
# Define supported embedding models
|
||||
SUPPORTED_OLLAMA_EMBED_MODELS = [
|
||||
'llama2',
|
||||
'mxbai-embed-large',
|
||||
'nomic-embed-text',
|
||||
'all-minilm',
|
||||
'stable-code',
|
||||
'bge-m3',
|
||||
'bge-large',
|
||||
'paraphrase-multilingual',
|
||||
'snowflake-arctic-embed',
|
||||
]
|
||||
|
||||
|
||||
def check_llama_index():
|
||||
"""Utility function to check the availability of llama_index.
|
||||
|
||||
Raises:
|
||||
ImportError: If llama_index is not available.
|
||||
"""
|
||||
if not LLAMA_INDEX_AVAILABLE:
|
||||
raise ImportError(
|
||||
'llama_index and its dependencies are not installed. '
|
||||
'To use memory features, please run: poetry install --with llama-index.'
|
||||
)
|
||||
|
||||
|
||||
class EmbeddingsLoader:
|
||||
"""Loader for embedding model initialization."""
|
||||
|
||||
@staticmethod
|
||||
def get_embedding_model(strategy: str, llm_config: LLMConfig) -> 'BaseEmbedding':
|
||||
"""Initialize and return the appropriate embedding model based on the strategy.
|
||||
|
||||
Parameters:
|
||||
- strategy: The embedding strategy to use.
|
||||
- llm_config: Configuration for the LLM.
|
||||
|
||||
Returns:
|
||||
- An instance of the selected embedding model or None.
|
||||
"""
|
||||
|
||||
if strategy in SUPPORTED_OLLAMA_EMBED_MODELS:
|
||||
from llama_index.embeddings.ollama import OllamaEmbedding
|
||||
|
||||
return OllamaEmbedding(
|
||||
model_name=strategy,
|
||||
base_url=llm_config.embedding_base_url,
|
||||
ollama_additional_kwargs={'mirostat': 0},
|
||||
)
|
||||
elif strategy == 'openai':
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
|
||||
return OpenAIEmbedding(
|
||||
model='text-embedding-ada-002',
|
||||
api_key=llm_config.api_key.get_secret_value()
|
||||
if llm_config.api_key
|
||||
else None,
|
||||
)
|
||||
elif strategy == 'azureopenai':
|
||||
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
|
||||
|
||||
return AzureOpenAIEmbedding(
|
||||
model='text-embedding-ada-002',
|
||||
deployment_name=llm_config.embedding_deployment_name,
|
||||
api_key=llm_config.api_key.get_secret_value()
|
||||
if llm_config.api_key
|
||||
else None,
|
||||
azure_endpoint=llm_config.base_url,
|
||||
api_version=llm_config.api_version,
|
||||
)
|
||||
elif strategy == 'voyage':
|
||||
from llama_index.embeddings.voyageai import VoyageEmbedding
|
||||
|
||||
return VoyageEmbedding(
|
||||
model_name='voyage-code-3',
|
||||
)
|
||||
elif (strategy is not None) and (strategy.lower() == 'none'):
|
||||
# TODO: this works but is not elegant enough. The incentive is when
|
||||
# an agent using embeddings is not used, there is no reason we need to
|
||||
# initialize an embedding model
|
||||
return None
|
||||
else:
|
||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||
|
||||
# initialize the local embedding model
|
||||
local_embed_model = HuggingFaceEmbedding(
|
||||
model_name='BAAI/bge-small-en-v1.5'
|
||||
)
|
||||
|
||||
# for local embeddings, we need torch
|
||||
import torch
|
||||
|
||||
# choose the best device
|
||||
# first determine what is available: CUDA, MPS, or CPU
|
||||
if torch.cuda.is_available():
|
||||
device = 'cuda'
|
||||
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
|
||||
device = 'mps'
|
||||
else:
|
||||
device = 'cpu'
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = ''
|
||||
os.environ['PYTORCH_FORCE_CPU'] = (
|
||||
'1' # try to force CPU to avoid errors
|
||||
)
|
||||
|
||||
# override CUDA availability
|
||||
torch.cuda.is_available = lambda: False
|
||||
|
||||
# disable MPS to avoid errors
|
||||
if device != 'mps' and hasattr(torch.backends, 'mps'):
|
||||
torch.backends.mps.is_available = lambda: False
|
||||
torch.backends.mps.is_built = False
|
||||
|
||||
# the device being used
|
||||
logger.debug(f'Using device for embeddings: {device}')
|
||||
|
||||
return local_embed_model
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
# Utility functions to run pipelines, split out for profiling
|
||||
# --------------------------------------------------------------------------
|
||||
def run_pipeline(
|
||||
embed_model: 'BaseEmbedding', documents: list['Document'], num_workers: int
|
||||
) -> list['TextNode']:
|
||||
"""Run a pipeline embedding documents."""
|
||||
|
||||
# set up a pipeline with the transformations to make
|
||||
pipeline = IngestionPipeline(
|
||||
transformations=[
|
||||
embed_model,
|
||||
],
|
||||
)
|
||||
|
||||
# run the pipeline with num_workers
|
||||
nodes = pipeline.run(
|
||||
documents=documents, show_progress=True, num_workers=num_workers
|
||||
)
|
||||
return nodes
|
||||
|
||||
|
||||
def insert_batch_docs(
|
||||
index: 'VectorStoreIndex', documents: list['Document'], num_workers: int
|
||||
) -> list['TextNode']:
|
||||
"""Run the document indexing in parallel."""
|
||||
results = Parallel(n_jobs=num_workers, backend='threading')(
|
||||
delayed(index.insert)(doc) for doc in documents
|
||||
)
|
||||
return results
|
||||
@@ -79,16 +79,6 @@ memory-profiler = "^0.61.0"
|
||||
daytona-sdk = "0.9.1"
|
||||
python-json-logger = "^3.2.1"
|
||||
|
||||
[tool.poetry.group.llama-index.dependencies]
|
||||
llama-index = "*"
|
||||
llama-index-vector-stores-chroma = "*"
|
||||
chromadb = "*"
|
||||
llama-index-embeddings-huggingface = "*"
|
||||
torch = "2.5.1"
|
||||
llama-index-embeddings-azure-openai = "*"
|
||||
llama-index-embeddings-ollama = "*"
|
||||
voyageai = "*"
|
||||
llama-index-embeddings-voyageai = "*"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
ruff = "0.9.8"
|
||||
|
||||
@@ -1,448 +0,0 @@
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.message import ImageContent, Message, TextContent
|
||||
from openhands.events.action import (
|
||||
AgentFinishAction,
|
||||
CmdRunAction,
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.events.event import Event, EventSource, FileEditSource, FileReadSource
|
||||
from openhands.events.observation import CmdOutputObservation
|
||||
from openhands.events.observation.browse import BrowserOutputObservation
|
||||
from openhands.events.observation.commands import (
|
||||
CmdOutputMetadata,
|
||||
IPythonRunCellObservation,
|
||||
)
|
||||
from openhands.events.observation.delegate import AgentDelegateObservation
|
||||
from openhands.events.observation.error import ErrorObservation
|
||||
from openhands.events.observation.files import FileEditObservation, FileReadObservation
|
||||
from openhands.events.observation.reject import UserRejectObservation
|
||||
from openhands.events.tool import ToolCallMetadata
|
||||
from openhands.memory.conversation_memory import ConversationMemory
|
||||
from openhands.utils.prompt import PromptManager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def conversation_memory():
|
||||
prompt_manager = MagicMock(spec=PromptManager)
|
||||
prompt_manager.get_system_message.return_value = 'System message'
|
||||
return ConversationMemory(prompt_manager)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_state():
|
||||
state = MagicMock(spec=State)
|
||||
state.history = []
|
||||
return state
|
||||
|
||||
|
||||
def test_process_initial_messages(conversation_memory):
|
||||
messages = conversation_memory.process_initial_messages(with_caching=False)
|
||||
assert len(messages) == 1
|
||||
assert messages[0].role == 'system'
|
||||
assert messages[0].content[0].text == 'System message'
|
||||
assert messages[0].content[0].cache_prompt is False
|
||||
|
||||
messages = conversation_memory.process_initial_messages(with_caching=True)
|
||||
assert messages[0].content[0].cache_prompt is True
|
||||
|
||||
|
||||
def test_process_events_with_message_action(conversation_memory):
|
||||
user_message = MessageAction(content='Hello')
|
||||
user_message._source = EventSource.USER
|
||||
assistant_message = MessageAction(content='Hi there')
|
||||
assistant_message._source = EventSource.AGENT
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[user_message, assistant_message],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 3
|
||||
assert messages[0].role == 'system'
|
||||
assert messages[1].role == 'user'
|
||||
assert messages[1].content[0].text == 'Hello'
|
||||
assert messages[2].role == 'assistant'
|
||||
assert messages[2].content[0].text == 'Hi there'
|
||||
|
||||
|
||||
def test_process_events_with_cmd_output_observation(conversation_memory):
|
||||
obs = CmdOutputObservation(
|
||||
command='echo hello',
|
||||
content='Command output',
|
||||
metadata=CmdOutputMetadata(
|
||||
exit_code=0,
|
||||
prefix='[THIS IS PREFIX]',
|
||||
suffix='[THIS IS SUFFIX]',
|
||||
),
|
||||
)
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
assert 'Observed result of command executed by user:' in result.content[0].text
|
||||
assert '[Command finished with exit code 0]' in result.content[0].text
|
||||
assert '[THIS IS PREFIX]' in result.content[0].text
|
||||
assert '[THIS IS SUFFIX]' in result.content[0].text
|
||||
|
||||
|
||||
def test_process_events_with_ipython_run_cell_observation(conversation_memory):
|
||||
obs = IPythonRunCellObservation(
|
||||
code='plt.plot()',
|
||||
content='IPython output\n',
|
||||
)
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
assert 'IPython output' in result.content[0].text
|
||||
assert (
|
||||
' already displayed to user'
|
||||
in result.content[0].text
|
||||
)
|
||||
assert 'ABC123' not in result.content[0].text
|
||||
|
||||
|
||||
def test_process_events_with_agent_delegate_observation(conversation_memory):
|
||||
obs = AgentDelegateObservation(
|
||||
content='Content', outputs={'content': 'Delegated agent output'}
|
||||
)
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
assert 'Delegated agent output' in result.content[0].text
|
||||
|
||||
|
||||
def test_process_events_with_error_observation(conversation_memory):
|
||||
obs = ErrorObservation('Error message')
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
assert 'Error message' in result.content[0].text
|
||||
assert 'Error occurred in processing last action' in result.content[0].text
|
||||
|
||||
|
||||
def test_process_events_with_unknown_observation(conversation_memory):
|
||||
# Create a mock that inherits from Event but not Action or Observation
|
||||
obs = Mock(spec=Event)
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError, match='Unknown event type'):
|
||||
conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
|
||||
def test_process_events_with_file_edit_observation(conversation_memory):
|
||||
obs = FileEditObservation(
|
||||
path='/test/file.txt',
|
||||
prev_exist=True,
|
||||
old_content='old content',
|
||||
new_content='new content',
|
||||
content='diff content',
|
||||
impl_source=FileEditSource.LLM_BASED_EDIT,
|
||||
)
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
assert '[Existing file /test/file.txt is edited with' in result.content[0].text
|
||||
|
||||
|
||||
def test_process_events_with_file_read_observation(conversation_memory):
|
||||
obs = FileReadObservation(
|
||||
path='/test/file.txt',
|
||||
content='File content',
|
||||
impl_source=FileReadSource.DEFAULT,
|
||||
)
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
assert result.content[0].text == 'File content'
|
||||
|
||||
|
||||
def test_process_events_with_browser_output_observation(conversation_memory):
|
||||
obs = BrowserOutputObservation(
|
||||
url='http://example.com',
|
||||
trigger_by_action='browse',
|
||||
screenshot='',
|
||||
content='Page loaded',
|
||||
error=False,
|
||||
)
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
assert '[Current URL: http://example.com]' in result.content[0].text
|
||||
|
||||
|
||||
def test_process_events_with_user_reject_observation(conversation_memory):
|
||||
obs = UserRejectObservation('Action rejected')
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
assert 'Action rejected' in result.content[0].text
|
||||
assert '[Last action has been rejected by the user]' in result.content[0].text
|
||||
|
||||
|
||||
def test_process_events_with_function_calling_observation(conversation_memory):
|
||||
mock_response = {
|
||||
'id': 'mock_id',
|
||||
'total_calls_in_response': 1,
|
||||
'choices': [{'message': {'content': 'Task completed'}}],
|
||||
}
|
||||
obs = CmdOutputObservation(
|
||||
command='echo hello',
|
||||
content='Command output',
|
||||
command_id=1,
|
||||
exit_code=0,
|
||||
)
|
||||
obs.tool_call_metadata = ToolCallMetadata(
|
||||
tool_call_id='123',
|
||||
function_name='execute_bash',
|
||||
model_response=mock_response,
|
||||
total_calls_in_response=1,
|
||||
)
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# No direct message when using function calling
|
||||
assert len(messages) == 1 # Only the initial system message
|
||||
|
||||
|
||||
def test_process_events_with_message_action_with_image(conversation_memory):
|
||||
action = MessageAction(
|
||||
content='Message with image',
|
||||
image_urls=['http://example.com/image.jpg'],
|
||||
)
|
||||
action._source = EventSource.AGENT
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[action],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=True,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert result.role == 'assistant'
|
||||
assert len(result.content) == 2
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
assert isinstance(result.content[1], ImageContent)
|
||||
assert result.content[0].text == 'Message with image'
|
||||
assert result.content[1].image_urls == ['http://example.com/image.jpg']
|
||||
|
||||
|
||||
def test_process_events_with_user_cmd_action(conversation_memory):
|
||||
action = CmdRunAction(command='ls -l')
|
||||
action._source = EventSource.USER
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[action],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
assert 'User executed the command' in result.content[0].text
|
||||
assert 'ls -l' in result.content[0].text
|
||||
|
||||
|
||||
def test_process_events_with_agent_finish_action_with_tool_metadata(
|
||||
conversation_memory,
|
||||
):
|
||||
mock_response = {
|
||||
'id': 'mock_id',
|
||||
'total_calls_in_response': 1,
|
||||
'choices': [{'message': {'content': 'Task completed'}}],
|
||||
}
|
||||
|
||||
action = AgentFinishAction(thought='Initial thought')
|
||||
action._source = EventSource.AGENT
|
||||
action.tool_call_metadata = ToolCallMetadata(
|
||||
tool_call_id='123',
|
||||
function_name='finish',
|
||||
model_response=mock_response,
|
||||
total_calls_in_response=1,
|
||||
)
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[action],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert result.role == 'assistant'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
assert 'Initial thought\nTask completed' in result.content[0].text
|
||||
|
||||
|
||||
def test_apply_prompt_caching(conversation_memory):
|
||||
messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')]),
|
||||
Message(role='user', content=[TextContent(text='User message')]),
|
||||
Message(role='assistant', content=[TextContent(text='Assistant message')]),
|
||||
Message(role='user', content=[TextContent(text='Another user message')]),
|
||||
]
|
||||
|
||||
conversation_memory.apply_prompt_caching(messages)
|
||||
|
||||
# Only the last user message should have cache_prompt=True
|
||||
assert messages[0].content[0].cache_prompt is False
|
||||
assert messages[1].content[0].cache_prompt is False
|
||||
assert messages[2].content[0].cache_prompt is False
|
||||
assert messages[3].content[0].cache_prompt is True
|
||||
@@ -1,251 +0,0 @@
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.core.config import AgentConfig, LLMConfig
|
||||
from openhands.events.event import Event, EventSource
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.memory.long_term_memory import LongTermMemory
|
||||
from openhands.storage.files import FileStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm_config() -> LLMConfig:
|
||||
config = MagicMock(spec=LLMConfig)
|
||||
config.embedding_model = 'test_embedding_model'
|
||||
config.api_key = 'test_api_key'
|
||||
config.api_version = 'v1'
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent_config() -> AgentConfig:
|
||||
config = AgentConfig(
|
||||
memory_enabled=True,
|
||||
memory_max_threads=4,
|
||||
llm_config='test_llm_config',
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_file_store() -> FileStore:
|
||||
store = MagicMock(spec=FileStore)
|
||||
store.sid = 'test_session'
|
||||
return store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_event_stream(mock_file_store: FileStore) -> EventStream:
|
||||
with patch('openhands.events.stream.EventStream') as MockEventStream:
|
||||
instance = MockEventStream.return_value
|
||||
instance.sid = 'test_session'
|
||||
instance.get_events = MagicMock()
|
||||
return instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def long_term_memory(
|
||||
mock_llm_config: LLMConfig,
|
||||
mock_agent_config: AgentConfig,
|
||||
mock_event_stream: EventStream,
|
||||
) -> LongTermMemory:
|
||||
mod = LongTermMemory.__module__
|
||||
with patch(f'{mod}.chromadb.PersistentClient') as mock_chroma_client:
|
||||
mock_collection = MagicMock()
|
||||
mock_chroma_client.return_value.get_or_create_collection.return_value = (
|
||||
mock_collection
|
||||
)
|
||||
with (
|
||||
patch(f'{mod}.ChromaVectorStore', MagicMock()),
|
||||
patch(f'{mod}.EmbeddingsLoader', MagicMock()),
|
||||
patch(f'{mod}.VectorStoreIndex', MagicMock()),
|
||||
):
|
||||
memory = LongTermMemory(
|
||||
llm_config=mock_llm_config,
|
||||
agent_config=mock_agent_config,
|
||||
event_stream=mock_event_stream,
|
||||
)
|
||||
memory.collection = mock_collection
|
||||
return memory
|
||||
|
||||
|
||||
def _create_action_event(action: str) -> Event:
|
||||
"""Helper function to create an action event."""
|
||||
event = Event()
|
||||
event._id = -1
|
||||
event._timestamp = datetime.now(timezone.utc).isoformat()
|
||||
event._source = EventSource.AGENT
|
||||
event.action = action
|
||||
return event
|
||||
|
||||
|
||||
def _create_observation_event(observation: str) -> Event:
|
||||
"""Helper function to create an observation event."""
|
||||
event = Event()
|
||||
event._id = -1
|
||||
event._timestamp = datetime.now(timezone.utc).isoformat()
|
||||
event._source = EventSource.ENVIRONMENT
|
||||
event.observation = observation
|
||||
return event
|
||||
|
||||
|
||||
def test_add_event_with_action(long_term_memory: LongTermMemory):
|
||||
event = _create_action_event('test_action')
|
||||
long_term_memory._add_document = MagicMock()
|
||||
long_term_memory.add_event(event)
|
||||
assert long_term_memory.thought_idx == 1
|
||||
long_term_memory._add_document.assert_called_once()
|
||||
_, kwargs = long_term_memory._add_document.call_args
|
||||
assert kwargs['document'].extra_info['type'] == 'action'
|
||||
assert kwargs['document'].extra_info['id'] == 'test_action'
|
||||
|
||||
|
||||
def test_add_event_with_observation(long_term_memory: LongTermMemory):
|
||||
event = _create_observation_event('test_observation')
|
||||
long_term_memory._add_document = MagicMock()
|
||||
long_term_memory.add_event(event)
|
||||
assert long_term_memory.thought_idx == 1
|
||||
long_term_memory._add_document.assert_called_once()
|
||||
_, kwargs = long_term_memory._add_document.call_args
|
||||
assert kwargs['document'].extra_info['type'] == 'observation'
|
||||
assert kwargs['document'].extra_info['id'] == 'test_observation'
|
||||
|
||||
|
||||
def test_add_event_with_missing_keys(long_term_memory: LongTermMemory):
|
||||
# Creating an event with additional unexpected attributes
|
||||
event = Event()
|
||||
event._id = -1
|
||||
event._timestamp = datetime.now(timezone.utc).isoformat()
|
||||
event._source = EventSource.AGENT
|
||||
event.action = 'test_action'
|
||||
event.unexpected_key = 'value'
|
||||
|
||||
long_term_memory._add_document = MagicMock()
|
||||
long_term_memory.add_event(event)
|
||||
assert long_term_memory.thought_idx == 1
|
||||
long_term_memory._add_document.assert_called_once()
|
||||
_, kwargs = long_term_memory._add_document.call_args
|
||||
assert kwargs['document'].extra_info['type'] == 'action'
|
||||
assert kwargs['document'].extra_info['id'] == 'test_action'
|
||||
|
||||
|
||||
def test_events_to_docs_no_events(
|
||||
long_term_memory: LongTermMemory, mock_event_stream: EventStream
|
||||
):
|
||||
mock_event_stream.get_events.side_effect = FileNotFoundError
|
||||
|
||||
# convert events to documents
|
||||
documents = long_term_memory._events_to_docs()
|
||||
|
||||
# since get_events raises, documents should be empty
|
||||
assert len(documents) == 0
|
||||
|
||||
# thought_idx remains unchanged
|
||||
assert long_term_memory.thought_idx == 0
|
||||
|
||||
|
||||
def test_load_events_into_index_with_invalid_json(
|
||||
long_term_memory: LongTermMemory, mock_event_stream: EventStream
|
||||
):
|
||||
"""Test loading events with malformed event data."""
|
||||
# Simulate an event that causes event_to_memory to raise a JSONDecodeError
|
||||
with patch(
|
||||
'openhands.memory.long_term_memory.event_to_memory',
|
||||
side_effect=json.JSONDecodeError('Expecting value', '', 0),
|
||||
):
|
||||
event = _create_action_event('invalid_action')
|
||||
mock_event_stream.get_events.return_value = [event]
|
||||
|
||||
# convert events to documents
|
||||
documents = long_term_memory._events_to_docs()
|
||||
|
||||
# since event_to_memory raises, documents should be empty
|
||||
assert len(documents) == 0
|
||||
|
||||
# thought_idx remains unchanged
|
||||
assert long_term_memory.thought_idx == 0
|
||||
|
||||
|
||||
def test_embeddings_inserted_into_chroma(long_term_memory: LongTermMemory):
|
||||
event = _create_action_event('test_action')
|
||||
long_term_memory._add_document = MagicMock()
|
||||
long_term_memory.add_event(event)
|
||||
long_term_memory._add_document.assert_called()
|
||||
_, kwargs = long_term_memory._add_document.call_args
|
||||
assert 'document' in kwargs
|
||||
assert (
|
||||
kwargs['document'].text
|
||||
== '{"source": "agent", "action": "test_action", "args": {}}'
|
||||
)
|
||||
|
||||
|
||||
def test_search_returns_correct_results(long_term_memory: LongTermMemory):
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.retrieve.return_value = [
|
||||
MagicMock(get_text=MagicMock(return_value='result1')),
|
||||
MagicMock(get_text=MagicMock(return_value='result2')),
|
||||
]
|
||||
with patch(
|
||||
'openhands.memory.long_term_memory.VectorIndexRetriever',
|
||||
return_value=mock_retriever,
|
||||
):
|
||||
results = long_term_memory.search(query='test query', k=2)
|
||||
assert results == ['result1', 'result2']
|
||||
mock_retriever.retrieve.assert_called_once_with('test query')
|
||||
|
||||
|
||||
def test_search_with_no_results(long_term_memory: LongTermMemory):
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.retrieve.return_value = []
|
||||
with patch(
|
||||
'openhands.memory.long_term_memory.VectorIndexRetriever',
|
||||
return_value=mock_retriever,
|
||||
):
|
||||
results = long_term_memory.search(query='no results', k=5)
|
||||
assert results == []
|
||||
mock_retriever.retrieve.assert_called_once_with('no results')
|
||||
|
||||
|
||||
def test_add_event_increment_thought_idx(long_term_memory: LongTermMemory):
|
||||
event1 = _create_action_event('action1')
|
||||
event2 = _create_observation_event('observation1')
|
||||
long_term_memory.add_event(event1)
|
||||
long_term_memory.add_event(event2)
|
||||
assert long_term_memory.thought_idx == 2
|
||||
|
||||
|
||||
def test_load_events_batch_insert(
|
||||
long_term_memory: LongTermMemory, mock_event_stream: EventStream
|
||||
):
|
||||
event1 = _create_action_event('action1')
|
||||
event2 = _create_observation_event('observation1')
|
||||
event3 = _create_action_event('action2')
|
||||
mock_event_stream.get_events.return_value = [event1, event2, event3]
|
||||
|
||||
# Mock insert_batch_docs
|
||||
with patch('openhands.utils.embeddings.insert_batch_docs') as mock_run_docs:
|
||||
# convert events to documents
|
||||
documents = long_term_memory._events_to_docs()
|
||||
|
||||
# Mock the insert_batch_docs to simulate document insertion
|
||||
mock_run_docs.return_value = []
|
||||
|
||||
# Call insert_batch_docs with the documents
|
||||
mock_run_docs(
|
||||
index=long_term_memory.index,
|
||||
documents=documents,
|
||||
num_workers=long_term_memory.memory_max_threads,
|
||||
)
|
||||
|
||||
# Assert that insert_batch_docs was called with the correct arguments
|
||||
mock_run_docs.assert_called_once_with(
|
||||
index=long_term_memory.index,
|
||||
documents=documents,
|
||||
num_workers=long_term_memory.memory_max_threads,
|
||||
)
|
||||
|
||||
# Check if thought_idx was incremented correctly
|
||||
assert long_term_memory.thought_idx == 3
|
||||
Reference in New Issue
Block a user