Add support for customized vector db and embedding functions (#161)

* Add custom embedding function

* Add support to custom vector db

* Improve docstring

* Improve docstring

* Improve docstring

* Add support for customized is_termination_msg function

* Add a test for customized vector db with lancedb

* Fix tests

* Add test for embedding_function

* Update docstring
This commit is contained in:
Li Jiang
2023-10-10 20:53:18 +08:00
committed by GitHub
parent 37a07a83c3
commit fa6e2a52c0
6 changed files with 192 additions and 15 deletions

View File

@@ -100,6 +100,70 @@ class TestRetrieveUtils:
results = query_vector_db(["autogen"], client=client)
assert isinstance(results, dict) and any("autogen" in res[0].lower() for res in results.get("documents", []))
def test_custom_vector_db(self):
    """Check that a ``RetrieveUserProxyAgent`` subclass can use its own vector db.

    The subclass overrides ``query_vector_db`` and ``retrieve_docs`` so that
    retrieval is served by a small LanceDB table. The test is skipped when
    ``lancedb`` is not installed.
    """
    import shutil
    import tempfile

    try:
        import lancedb
    except ImportError:
        # Report a skip rather than a silent pass when the optional
        # dependency is missing.
        pytest.skip("lancedb is not installed")

    from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent

    # Use a fresh directory per run so a table left behind by an earlier run
    # can never be queried in place of the rows defined below.
    db_path = tempfile.mkdtemp(prefix="lancedb_")

    def create_lancedb():
        """Create ``my_table`` with six 2-d vectors; ids 1, 3 and 5 mention "spark"."""
        db = lancedb.connect(db_path)
        data = [
            {"vector": [1.1, 1.2], "id": 1, "documents": "This is a test document spark"},
            {"vector": [0.2, 1.8], "id": 2, "documents": "This is another test document"},
            {"vector": [0.1, 0.3], "id": 3, "documents": "This is a third test document spark"},
            {"vector": [0.5, 0.7], "id": 4, "documents": "This is a fourth test document"},
            {"vector": [2.1, 1.3], "id": 5, "documents": "This is a fifth test document spark"},
            {"vector": [5.1, 8.3], "id": 6, "documents": "This is a sixth test document"},
        ]
        # The directory is new, so the table cannot already exist; any error
        # raised here is a genuine failure and should surface.
        db.create_table("my_table", data)

    class MyRetrieveUserProxyAgent(RetrieveUserProxyAgent):
        """Agent whose retrieval is backed by the LanceDB table created above."""

        def query_vector_db(
            self,
            query_texts,
            n_results=10,
            search_string="",
        ):
            """Search ``my_table`` and return the matching ids and documents.

            Args:
                query_texts: the query strings; only their presence is checked,
                    the search itself uses a fixed stand-in vector.
                n_results: maximum number of rows to return.
                search_string: substring that the ``documents`` column must contain.

            Returns:
                A dict with the keys ``"ids"`` and ``"documents"``, each a list.
            """
            if not query_texts:
                # Nothing to search for: return an empty, well-formed result.
                return {"ids": [], "documents": []}
            # Fixed stand-in for an embedding of the query text.
            vector = [0.1, 0.3]
            db = lancedb.connect(db_path)
            table = db.open_table("my_table")
            query = table.search(vector).where(f"documents LIKE '%{search_string}%'").limit(n_results).to_df()
            return {"ids": query["id"].tolist(), "documents": query["documents"].tolist()}

        def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""):
            """Query the custom vector db for ``problem`` and store the result in ``self._results``."""
            results = self.query_vector_db(
                query_texts=[problem],
                n_results=n_results,
                search_string=search_string,
            )
            self._results = results
            print("doc_ids: ", results["ids"])

    try:
        ragproxyagent = MyRetrieveUserProxyAgent(
            name="ragproxyagent",
            human_input_mode="NEVER",
            max_consecutive_auto_reply=2,
            retrieve_config={
                "task": "qa",
                "chunk_token_size": 2000,
                # Placeholder value: the overridden query_vector_db above
                # never reads the configured client.
                "client": "__",
                "embedding_model": "all-mpnet-base-v2",
            },
        )

        create_lancedb()
        ragproxyagent.retrieve_docs("This is a test document spark", n_results=10, search_string="spark")
        # Expected: the three "spark" rows, ordered by increasing distance
        # from the query vector [0.1, 0.3].
        assert ragproxyagent._results["ids"] == [3, 1, 5]
    finally:
        # Best-effort cleanup of the temporary database directory.
        shutil.rmtree(db_path, ignore_errors=True)
# Allow running this test module directly as a script; hands control to pytest.
if __name__ == "__main__":
    pytest.main()