From 80954e4b8d0752fd4772f339dee419fbf1debc6f Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Wed, 25 Oct 2023 00:09:25 +0800 Subject: [PATCH] Fix tmp dir not exists (#401) * Fix tmp dir not exists * Update tests to make it more clear * Add check if save path is not None --- autogen/retrieve_utils.py | 3 +++ test/test_retrieve_utils.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/autogen/retrieve_utils.py b/autogen/retrieve_utils.py index c660fa85d..c29ced376 100644 --- a/autogen/retrieve_utils.py +++ b/autogen/retrieve_utils.py @@ -254,7 +254,10 @@ def get_files_from_dir(dir_path: Union[str, List[str]], types: list = TEXT_FORMA def get_file_from_url(url: str, save_path: str = None): """Download a file from a URL.""" if save_path is None: + os.makedirs("/tmp/chromadb", exist_ok=True) save_path = os.path.join("/tmp/chromadb", os.path.basename(url)) + else: + os.makedirs(os.path.dirname(save_path), exist_ok=True) with requests.get(url, stream=True) as r: r.raise_for_status() with open(save_path, "wb") as f: diff --git a/test/test_retrieve_utils.py b/test/test_retrieve_utils.py index a1c70d9cf..81fb1a096 100644 --- a/test/test_retrieve_utils.py +++ b/test/test_retrieve_utils.py @@ -140,7 +140,7 @@ class TestRetrieveUtils: db = lancedb.connect(db_path) table = db.open_table("my_table") query = table.search(vector).where(f"documents LIKE '%{search_string}%'").limit(n_results).to_df() - return {"ids": query["id"].tolist(), "documents": query["documents"].tolist()} + return {"ids": [query["id"].tolist()], "documents": [query["documents"].tolist()]} def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""): results = self.query_vector_db( @@ -166,7 +166,7 @@ class TestRetrieveUtils: create_lancedb() ragragproxyagent.retrieve_docs("This is a test document spark", n_results=10, search_string="spark") - assert ragragproxyagent._results["ids"] == [3, 1, 5] + assert ragragproxyagent._results["ids"] == [[3, 1, 5]] def test_custom_text_split_function(self): def custom_text_split_function(text):