Refactor embeddings (#4219)

Engel Nyst
2024-10-05 20:59:08 +02:00
committed by GitHub
parent 40d2935911
commit 9d0e6a24bc
7 changed files with 571 additions and 151 deletions


@@ -1,189 +1,187 @@
-import threading
+import json
 
-from openai._exceptions import APIConnectionError, InternalServerError, RateLimitError
-from tenacity import (
-    retry,
-    retry_if_exception_type,
-    stop_after_attempt,
-    wait_random_exponential,
-)
-
-from openhands.core.config import LLMConfig
+from openhands.core.config import AgentConfig, LLMConfig
 from openhands.core.logger import openhands_logger as logger
-from openhands.core.utils import json
-from openhands.utils.tenacity_stop import stop_if_should_exit
-
-try:
-    import chromadb
-    import llama_index.embeddings.openai.base as llama_openai
-    from llama_index.core import Document, VectorStoreIndex
-    from llama_index.core.retrievers import VectorIndexRetriever
-    from llama_index.vector_stores.chroma import ChromaVectorStore
-
-    LLAMA_INDEX_AVAILABLE = True
-except ImportError:
-    LLAMA_INDEX_AVAILABLE = False
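
The new import block pulls the availability flag and a check_llama_index guard from openhands.utils.embeddings; the old inline check in __init__ (visible below as removed lines) suggests its contract. A minimal sketch under that assumption, reusing the removed code's error message:

def check_llama_index() -> None:
    # guard for the optional llama-index extras; raise early with a helpful
    # message instead of failing later on a missing import
    if not LLAMA_INDEX_AVAILABLE:
        raise ImportError(
            'llama_index and its dependencies are not installed. '
            'To use LongTermMemory, please run: poetry install --with llama-index'
        )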
 
+# Conditional imports based on llama_index availability
 if LLAMA_INDEX_AVAILABLE:
-    # TODO: this could be made configurable
-    num_retries: int = 10
-    retry_min_wait: int = 3
-    retry_max_wait: int = 300
-
-    # llama-index includes a retry decorator around openai.get_embeddings() function
-    # it is initialized with hard-coded values and errors
-    # this non-customizable behavior is creating issues when it's retrying faster than providers' rate limits
-    # this block attempts to banish it and replace it with our decorator, to allow users to set their own limits
-    if hasattr(llama_openai.get_embeddings, '__wrapped__'):
-        original_get_embeddings = llama_openai.get_embeddings.__wrapped__
-    else:
-        logger.warning('Cannot set custom retry limits.')
-        num_retries = 1
-        original_get_embeddings = llama_openai.get_embeddings
-
-    def attempt_on_error(retry_state):
-        logger.error(
-            f'{retry_state.outcome.exception()}. Attempt #{retry_state.attempt_number} | You can customize retry values in the configuration.',
-            exc_info=False,
-        )
-        return None
-
-    @retry(
-        reraise=True,
-        stop=stop_after_attempt(num_retries) | stop_if_should_exit(),
-        wait=wait_random_exponential(min=retry_min_wait, max=retry_max_wait),
-        retry=retry_if_exception_type(
-            (RateLimitError, APIConnectionError, InternalServerError)
-        ),
-        after=attempt_on_error,
-    )
-    def wrapper_get_embeddings(*args, **kwargs):
-        return original_get_embeddings(*args, **kwargs)
-
-    llama_openai.get_embeddings = wrapper_get_embeddings
+    import chromadb
+    from llama_index.core import Document
+    from llama_index.core.indices.vector_store.base import VectorStoreIndex
+    from llama_index.core.indices.vector_store.retrievers.retriever import (
+        VectorIndexRetriever,
+    )
+    from llama_index.core.schema import TextNode
+    from llama_index.vector_stores.chroma import ChromaVectorStore
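
The removed block above existed because llama-index ships get_embeddings pre-wrapped with a hard-coded tenacity policy; recovering the undecorated function through __wrapped__ and rewrapping it lets users set their own limits. A self-contained sketch of that unwrap-and-rewrap pattern on a toy function (flaky_call and patched are illustrative names, not from the codebase):

from tenacity import retry, stop_after_attempt, wait_random_exponential

@retry(stop=stop_after_attempt(2))  # stands in for the library's hard-coded policy
def flaky_call(texts: list[str]) -> list[list[float]]:
    raise ConnectionError('provider rate limit')

# tenacity applies functools.wraps, so the undecorated function is reachable
original = getattr(flaky_call, '__wrapped__', flaky_call)

@retry(
    reraise=True,
    stop=stop_after_attempt(10),
    wait=wait_random_exponential(min=3, max=300),
)
def patched(*args, **kwargs):
    # delegate to the original; retries now follow the custom policy
    return original(*args, **kwargs)

flaky_call = patched  # callers transparently pick up the new policy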
-class EmbeddingsLoader:
-    """Loader for embedding model initialization."""
-
-    @staticmethod
-    def get_embedding_model(strategy: str, llm_config: LLMConfig):
-        supported_ollama_embed_models = [
-            'llama2',
-            'mxbai-embed-large',
-            'nomic-embed-text',
-            'all-minilm',
-            'stable-code',
-            'bge-m3',
-            'bge-large',
-            'paraphrase-multilingual',
-            'snowflake-arctic-embed',
-        ]
-        if strategy in supported_ollama_embed_models:
-            from llama_index.embeddings.ollama import OllamaEmbedding
-
-            return OllamaEmbedding(
-                model_name=strategy,
-                base_url=llm_config.embedding_base_url,
-                ollama_additional_kwargs={'mirostat': 0},
-            )
-        elif strategy == 'openai':
-            from llama_index.embeddings.openai import OpenAIEmbedding
-
-            return OpenAIEmbedding(
-                model='text-embedding-ada-002',
-                api_key=llm_config.api_key,
-            )
-        elif strategy == 'azureopenai':
-            from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
-
-            return AzureOpenAIEmbedding(
-                model='text-embedding-ada-002',
-                deployment_name=llm_config.embedding_deployment_name,
-                api_key=llm_config.api_key,
-                azure_endpoint=llm_config.base_url,
-                api_version=llm_config.api_version,
-            )
-        elif (strategy is not None) and (strategy.lower() == 'none'):
-            # TODO: this works but is not elegant enough. The incentive is when
-            # an agent using embeddings is not used, there is no reason we need to
-            # initialize an embedding model
-            return None
-        else:
-            from llama_index.embeddings.huggingface import HuggingFaceEmbedding
-
-            return HuggingFaceEmbedding(model_name='BAAI/bge-small-en-v1.5')
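
The class above moves out of this file and is now imported from openhands.utils.embeddings. A usage sketch for the relocated loader, assuming LLMConfig accepts these fields as keyword arguments:

from openhands.core.config import LLMConfig
from openhands.utils.embeddings import EmbeddingsLoader

config = LLMConfig(
    embedding_model='nomic-embed-text',  # routed to the OllamaEmbedding branch
    embedding_base_url='http://localhost:11434',
)
embed_model = EmbeddingsLoader.get_embedding_model(config.embedding_model, config)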
 
 class LongTermMemory:
     """Handles storing information for the agent to access later, using chromadb."""
 
-    def __init__(self, llm_config: LLMConfig, memory_max_threads: int = 1):
-        """Initialize the chromadb and set up ChromaVectorStore for later use."""
-        if not LLAMA_INDEX_AVAILABLE:
-            raise ImportError(
-                'llama_index and its dependencies are not installed. '
-                'To use LongTermMemory, please run: poetry install --with llama-index'
-            )
-        db = chromadb.Client(chromadb.Settings(anonymized_telemetry=False))
+    event_stream: EventStream
+
+    def __init__(
+        self,
+        llm_config: LLMConfig,
+        agent_config: AgentConfig,
+        event_stream: EventStream,
+    ):
+        """Initialize the chromadb and set up ChromaVectorStore for later use."""
+        check_llama_index()
+
+        # initialize the chromadb client
+        db = chromadb.PersistentClient(
+            path=f'./cache/sessions/{event_stream.sid}/memory',
+            # FIXME anonymized_telemetry=False,
+        )
         self.collection = db.get_or_create_collection(name='memories')
         vector_store = ChromaVectorStore(chroma_collection=self.collection)
+
+        # embedding model
         embedding_strategy = llm_config.embedding_model
-        embed_model = EmbeddingsLoader.get_embedding_model(
+        self.embed_model = EmbeddingsLoader.get_embedding_model(
             embedding_strategy, llm_config
         )
-        self.index = VectorStoreIndex.from_vector_store(vector_store, embed_model)
-        self.sema = threading.Semaphore(value=memory_max_threads)
-        self.thought_idx = 0
-        self._add_threads: list[threading.Thread] = []
+
+        # instantiate the index
+        self.index = VectorStoreIndex.from_vector_store(vector_store, self.embed_model)
+        self.thought_idx = 0
+
+        # initialize the event stream
+        self.event_stream = event_stream
+
+        # max of threads to run the pipeline
+        self.memory_max_threads = agent_config.memory_max_threads
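
Construction now takes three collaborators instead of an llm_config plus a thread count. A hypothetical wiring; InMemoryFileStore and the EventStream signature are assumptions, and 'local' simply falls through to the HuggingFace default in the loader:

from openhands.core.config import AgentConfig, LLMConfig
from openhands.events.stream import EventStream
from openhands.storage.memory import InMemoryFileStore  # assumed available

event_stream = EventStream(sid='session-1', file_store=InMemoryFileStore())
memory = LongTermMemory(
    llm_config=LLMConfig(embedding_model='local'),
    agent_config=AgentConfig(memory_max_threads=2),
    event_stream=event_stream,
)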
 
-    def add_event(self, event: dict):
+    def add_event(self, event: Event):
         """Adds a new event to the long term memory with a unique id.
 
         Parameters:
-        - event (dict): The new event to be added to memory
+        - event: The new event to be added to memory
         """
-        id = ''
-        t = ''
-        if 'action' in event:
-            t = 'action'
-            id = event['action']
-        elif 'observation' in event:
-            t = 'observation'
-            id = event['observation']
+        try:
+            # convert the event to a memory-friendly format, and don't truncate
+            event_data = event_to_memory(event, -1)
+        except (json.JSONDecodeError, KeyError, ValueError) as e:
+            logger.warning(f'Failed to process event: {e}')
+            return
+
+        # determine the event type and ID
+        event_type = ''
+        event_id = ''
+        if 'action' in event_data:
+            event_type = 'action'
+            event_id = event_data['action']
+        elif 'observation' in event_data:
+            event_type = 'observation'
+            event_id = event_data['observation']
+
+        # create a Document instance for the event
         doc = Document(
-            text=json.dumps(event),
+            text=json.dumps(event_data),
             doc_id=str(self.thought_idx),
             extra_info={
-                'type': t,
-                'id': id,
+                'type': event_type,
+                'id': event_id,
                 'idx': self.thought_idx,
             },
         )
         self.thought_idx += 1
-        logger.debug('Adding %s event to memory: %d', t, self.thought_idx)
-        thread = threading.Thread(target=self._add_doc, args=(doc,))
-        self._add_threads.append(thread)
-        thread.start()  # We add the doc concurrently so we don't have to wait ~500ms for the insert
-
-    def _add_doc(self, doc):
-        with self.sema:
-            self.index.insert(doc)
+        logger.debug('Adding %s event to memory: %d', event_type, self.thought_idx)
+        self._add_document(document=doc)
+
+    def _add_document(self, document: 'Document'):
+        """Inserts a single document into the index."""
+        self.index.insert_nodes([self._create_node(document)])
+
+    def _create_node(self, document: 'Document') -> 'TextNode':
+        """Create a TextNode from a Document instance."""
+        return TextNode(
+            text=document.text,
+            doc_id=document.doc_id,
+            extra_info=document.extra_info,
+        )
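
add_event now accepts an Event object rather than a raw dict and serializes it through event_to_memory. For instance, with MessageAction as a representative Event subclass:

from openhands.events.action import MessageAction

memory.add_event(MessageAction(content='Summarize the failing test output'))
# serialized as {'action': 'message', ...}; extra_info records type='action', id='message'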
 
-    def search(self, query: str, k: int = 10):
-        """Searches through the current memory using VectorIndexRetriever
+    def search(self, query: str, k: int = 10) -> list[str]:
+        """Searches through the current memory using VectorIndexRetriever.
 
         Parameters:
         - query (str): A query to match search results to
         - k (int): Number of top results to return
 
         Returns:
-        - list[str]: list of top k results found in current memory
+        - list[str]: List of top k results found in current memory
         """
         retriever = VectorIndexRetriever(
             index=self.index,
             similarity_top_k=k,
         )
         results = retriever.retrieve(query)
         for result in results:
             logger.debug(
                 f'Doc ID: {result.doc_id}:\n Text: {result.get_text()}\n Score: {result.score}'
             )
         return [r.get_text() for r in results]
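
Retrieval usage, continuing the sketch above:

hits = memory.search(query='what did the browser observation say?', k=5)
for text in hits:
    print(text)  # JSON strings of the stored events, best matches first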
 
+    def _events_to_docs(self) -> list['Document']:
+        """Convert all events from the EventStream to documents for batch insert into the index."""
+        try:
+            events = self.event_stream.get_events()
+        except Exception as e:
+            logger.debug(f'No events found for session {self.event_stream.sid}: {e}')
+            return []
+
+        documents: list[Document] = []
+        for event in events:
+            try:
+                # convert the event to a memory-friendly format, and don't truncate
+                event_data = event_to_memory(event, -1)
+
+                # determine the event type and ID
+                event_type = ''
+                event_id = ''
+                if 'action' in event_data:
+                    event_type = 'action'
+                    event_id = event_data['action']
+                elif 'observation' in event_data:
+                    event_type = 'observation'
+                    event_id = event_data['observation']
+
+                # create a Document instance for the event
+                doc = Document(
+                    text=json.dumps(event_data),
+                    doc_id=str(self.thought_idx),
+                    extra_info={
+                        'type': event_type,
+                        'id': event_id,
+                        'idx': self.thought_idx,
+                    },
+                )
+                documents.append(doc)
+                self.thought_idx += 1
+            except (json.JSONDecodeError, KeyError, ValueError) as e:
+                logger.warning(f'Failed to process event: {e}')
+                continue
+
+        if documents:
+            logger.debug(f'Batch inserting {len(documents)} documents into the index.')
+        else:
+            logger.debug('No valid documents found to insert into the index.')
+
+        return documents
+
+    def create_nodes(self, documents: list['Document']) -> list['TextNode']:
+        """Create nodes from a list of documents."""
+        return [self._create_node(doc) for doc in documents]
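
The two new helpers are plainly meant to compose into a batch load of the session history; a plausible pairing, though whether a caller does exactly this in the rest of the PR is not shown here:

docs = memory._events_to_docs()       # one Document per serializable event
nodes = memory.create_nodes(docs)     # TextNodes keyed by thought_idx
if nodes:
    memory.index.insert_nodes(nodes)  # single bulk insert instead of N single inserts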