mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Update no_update_context, fix upsert docs (#52)
* Update no_update_context, fix upsert docs * Recreate only once * Add comments to get_or_create --------- Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
@@ -125,7 +125,9 @@ class RetrieveUserProxyAgent(UserProxyAgent):
|
||||
- customized_prompt (Optional, str): the customized prompt for the retrieve chat. Default is None.
|
||||
- customized_answer_prefix (Optional, str): the customized answer prefix for the retrieve chat. Default is "".
|
||||
If not "" and the customized_answer_prefix is not in the answer, `Update Context` will be triggered.
|
||||
- no_update_context (Optional, bool): if True, will not apply `Update Context` for interactive retrieval. Default is False.
|
||||
- update_context (Optional, bool): if False, will not apply `Update Context` for interactive retrieval. Default is True.
|
||||
- get_or_create (Optional, bool): if True, will create/recreate a collection for the retrieve chat.
|
||||
This is the same as that used in chromadb. Default is False.
|
||||
**kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).
|
||||
"""
|
||||
super().__init__(
|
||||
@@ -148,7 +150,8 @@ class RetrieveUserProxyAgent(UserProxyAgent):
|
||||
self._embedding_model = self._retrieve_config.get("embedding_model", "all-MiniLM-L6-v2")
|
||||
self.customized_prompt = self._retrieve_config.get("customized_prompt", None)
|
||||
self.customized_answer_prefix = self._retrieve_config.get("customized_answer_prefix", "").upper()
|
||||
self.no_update_context = self._retrieve_config.get("no_update_context", False)
|
||||
self.update_context = self._retrieve_config.get("update_context", True)
|
||||
self._get_or_create = self._retrieve_config.get("get_or_create", False)
|
||||
self._context_max_tokens = self._max_tokens * 0.8
|
||||
self._collection = False # the collection is not created
|
||||
self._ipython = get_ipython()
|
||||
@@ -231,7 +234,7 @@ class RetrieveUserProxyAgent(UserProxyAgent):
|
||||
config: Optional[Any] = None,
|
||||
) -> Tuple[bool, Union[str, Dict, None]]:
|
||||
"""In this function, we will update the context and reset the conversation based on different conditions.
|
||||
We'll update the context and reset the conversation if no_update_context is False and either of the following:
|
||||
We'll update the context and reset the conversation if update_context is True and either of the following:
|
||||
(1) the last message contains "UPDATE CONTEXT",
|
||||
(2) the last message doesn't contain "UPDATE CONTEXT" and the customized_answer_prefix is not in the message.
|
||||
"""
|
||||
@@ -247,7 +250,7 @@ class RetrieveUserProxyAgent(UserProxyAgent):
|
||||
update_context_case2 = (
|
||||
self.customized_answer_prefix and self.customized_answer_prefix not in message.get("content", "").upper()
|
||||
)
|
||||
if (update_context_case1 or update_context_case2) and not self.no_update_context:
|
||||
if (update_context_case1 or update_context_case2) and self.update_context:
|
||||
print(colored("Updating context and resetting conversation.", "green"), flush=True)
|
||||
# extract the first sentence in the response as the intermediate answer
|
||||
_message = message.get("content", "").split("\n")[0].strip()
|
||||
@@ -286,7 +289,7 @@ class RetrieveUserProxyAgent(UserProxyAgent):
|
||||
return False, None
|
||||
|
||||
def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""):
|
||||
if not self._collection:
|
||||
if not self._collection or self._get_or_create:
|
||||
print("Trying to create collection.")
|
||||
create_vector_db_from_dir(
|
||||
dir_path=self._docs_path,
|
||||
@@ -296,8 +299,10 @@ class RetrieveUserProxyAgent(UserProxyAgent):
|
||||
chunk_mode=self._chunk_mode,
|
||||
must_break_at_empty_line=self._must_break_at_empty_line,
|
||||
embedding_model=self._embedding_model,
|
||||
get_or_create=self._get_or_create,
|
||||
)
|
||||
self._collection = True
|
||||
self._get_or_create = False
|
||||
|
||||
results = query_vector_db(
|
||||
query_texts=[problem],
|
||||
|
||||
@@ -208,18 +208,13 @@ def create_vector_db_from_dir(
|
||||
|
||||
chunks = split_files_to_chunks(get_files_from_dir(dir_path), max_tokens, chunk_mode, must_break_at_empty_line)
|
||||
print(f"Found {len(chunks)} chunks.")
|
||||
# upsert in batch of 40000
|
||||
for i in range(0, len(chunks), 40000):
|
||||
# Upsert in batch of 40000 or less if the total number of chunks is less than 40000
|
||||
for i in range(0, len(chunks), min(40000, len(chunks))):
|
||||
end_idx = i + min(40000, len(chunks) - i)
|
||||
collection.upsert(
|
||||
documents=chunks[
|
||||
i : i + 40000
|
||||
], # we handle tokenization, embedding, and indexing automatically. You can skip that and add your own embeddings as well
|
||||
ids=[f"doc_{i}" for i in range(i, i + 40000)], # unique for each doc
|
||||
documents=chunks[i:end_idx],
|
||||
ids=[f"doc_{j}" for j in range(i, end_idx)], # unique for each doc
|
||||
)
|
||||
collection.upsert(
|
||||
documents=chunks[i : len(chunks)],
|
||||
ids=[f"doc_{i}" for i in range(i, len(chunks))], # unique for each doc
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(f"{e}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user