mirror of
https://github.com/microsoft/autogen.git
synced 2026-02-05 16:05:03 -05:00
Add support to customized vectordb and embedding functions (#161)
* Add custom embedding function * Add support to custom vector db * Improve docstring * Improve docstring * Improve docstring * Add support to customized is_termination_msg function * Add a test for customize vector db with lancedb * Fix tests * Add test for embedding_function * Update docstring
This commit is contained in:
@@ -100,6 +100,70 @@ class TestRetrieveUtils:
|
||||
results = query_vector_db(["autogen"], client=client)
|
||||
assert isinstance(results, dict) and any("autogen" in res[0].lower() for res in results.get("documents", []))
|
||||
|
||||
def test_custom_vector_db(self):
    """Check that a RetrieveUserProxyAgent subclass can plug in its own vector db.

    The subclass overrides ``query_vector_db`` and ``retrieve_docs`` so that
    retrieval reads from a small LanceDB table built by this test. The test is
    reported as skipped (not passed) when ``lancedb`` is not installed.
    """
    import shutil
    import tempfile

    try:
        import lancedb
    except ImportError:
        # A bare ``return`` would be reported as a pass; make the gap visible.
        pytest.skip("lancedb is not installed")
    from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent

    # Use a fresh directory per run so a table left behind by an earlier run
    # can never be reused silently, and so nothing is left on disk afterwards.
    db_path = tempfile.mkdtemp(prefix="lancedb_")

    def create_lancedb():
        """Create ``my_table`` with six 2-d vectors; ids 1, 3 and 5 mention "spark"."""
        db = lancedb.connect(db_path)
        data = [
            {"vector": [1.1, 1.2], "id": 1, "documents": "This is a test document spark"},
            {"vector": [0.2, 1.8], "id": 2, "documents": "This is another test document"},
            {"vector": [0.1, 0.3], "id": 3, "documents": "This is a third test document spark"},
            {"vector": [0.5, 0.7], "id": 4, "documents": "This is a fourth test document"},
            {"vector": [2.1, 1.3], "id": 5, "documents": "This is a fifth test document spark"},
            {"vector": [5.1, 8.3], "id": 6, "documents": "This is a sixth test document"},
        ]
        # The directory is brand new, so any failure here is a real error and
        # must not be swallowed.
        db.create_table("my_table", data)

    class MyRetrieveUserProxyAgent(RetrieveUserProxyAgent):
        """Agent whose retrieval is backed by the LanceDB table created above."""

        def query_vector_db(
            self,
            query_texts,
            n_results=10,
            search_string="",
        ):
            """Return ``{"ids": [...], "documents": [...]}`` for rows matching ``search_string``.

            ``query_texts`` is only checked for emptiness: a fixed query vector
            stands in for a real embedding of the text.
            """
            if not query_texts:
                # Keep the return shape stable so callers can index the result.
                return {"ids": [], "documents": []}
            vector = [0.1, 0.3]  # identical to the vector stored for id 3
            db = lancedb.connect(db_path)
            table = db.open_table("my_table")
            query = table.search(vector).where(f"documents LIKE '%{search_string}%'").limit(n_results).to_df()
            return {"ids": query["id"].tolist(), "documents": query["documents"].tolist()}

        def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""):
            """Query the custom db for ``problem`` and store the result in ``self._results``."""
            results = self.query_vector_db(
                query_texts=[problem],
                n_results=n_results,
                search_string=search_string,
            )

            self._results = results
            print("doc_ids: ", results["ids"])

    try:
        ragproxyagent = MyRetrieveUserProxyAgent(
            name="ragproxyagent",
            human_input_mode="NEVER",
            max_consecutive_auto_reply=2,
            retrieve_config={
                "task": "qa",
                "chunk_token_size": 2000,
                # NOTE(review): placeholder value; presumably keeps the base class
                # from building its own client -- confirm against RetrieveUserProxyAgent.
                "client": "__",
                "embedding_model": "all-mpnet-base-v2",
            },
        )

        create_lancedb()
        ragproxyagent.retrieve_docs("This is a test document spark", n_results=10, search_string="spark")
        # Only the "spark" rows (ids 1, 3, 5) pass the filter; id 3 comes first
        # because its stored vector equals the query vector.
        assert ragproxyagent._results["ids"] == [3, 1, 5]
    finally:
        shutil.rmtree(db_path, ignore_errors=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # pytest.main() returns the run's exit code; propagate it so that running
    # this file as a script exits non-zero when a test fails.
    raise SystemExit(pytest.main())
|
||||
|
||||
Reference in New Issue
Block a user