From 4ffc57fdaf7957f878bf396676ee23db2566ead5 Mon Sep 17 00:00:00 2001 From: Reinier van der Leer Date: Sun, 9 Jun 2024 01:16:58 +0200 Subject: [PATCH] make `web_search` and `google` commands code-friendly --- autogpt/poetry.lock | 41 ++++++++-- autogpt/pyproject.toml | 1 + autogpt/tests/unit/test_web_search.py | 106 ++++++++++++-------------- forge/forge/components/web/search.py | 81 ++++++++------------ forge/poetry.lock | 35 ++++++++- forge/pyproject.toml | 1 + 6 files changed, 148 insertions(+), 117 deletions(-) diff --git a/autogpt/poetry.lock b/autogpt/poetry.lock index f67bf1db87..04cadd69ae 100644 --- a/autogpt/poetry.lock +++ b/autogpt/poetry.lock @@ -1793,22 +1793,38 @@ grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] [[package]] name = "google-api-python-client" -version = "2.114.0" +version = "2.132.0" description = "Google API Client Library for Python" optional = false python-versions = ">=3.7" files = [ - {file = "google-api-python-client-2.114.0.tar.gz", hash = "sha256:e041bbbf60e682261281e9d64b4660035f04db1cccba19d1d68eebc24d1465ed"}, - {file = "google_api_python_client-2.114.0-py2.py3-none-any.whl", hash = "sha256:690e0bb67d70ff6dea4e8a5d3738639c105a478ac35da153d3b2a384064e9e1a"}, + {file = "google-api-python-client-2.132.0.tar.gz", hash = "sha256:d6340dc83b72d72333cee5d50f7dcfecbff66a8783164090e945f985ec4c374d"}, + {file = "google_api_python_client-2.132.0-py2.py3-none-any.whl", hash = "sha256:cde87700bd4d37f39f5e940292c1c6cd0910990b5b01f50b1332a8cea38e8595"}, ] [package.dependencies] google-api-core = ">=1.31.5,<2.0.dev0 || >2.3.0,<3.0.0.dev0" -google-auth = ">=1.19.0,<3.0.0.dev0" -google-auth-httplib2 = ">=0.1.0" -httplib2 = ">=0.15.0,<1.dev0" +google-auth = ">=1.32.0,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0.dev0" +google-auth-httplib2 = ">=0.2.0,<1.0.0" +httplib2 = ">=0.19.0,<1.dev0" uritemplate = ">=3.0.1,<5" +[[package]] +name = "google-api-python-client-stubs" +version = "1.26.0" +description = "Type stubs for google-api-python-client" +optional = false +python-versions = "<4.0,>=3.7" +files = [ + {file = "google_api_python_client_stubs-1.26.0-py3-none-any.whl", hash = "sha256:0614b0cef5beac43e6ab02418f07e64ee66dc99ae4e377d54a155ac261533987"}, + {file = "google_api_python_client_stubs-1.26.0.tar.gz", hash = "sha256:f3b38b46f7b5cf4f6e7cc63ca554a2d23096d49c841f38b9ea553a5237074b56"}, +] + +[package.dependencies] +google-api-python-client = ">=2.130.0" +types-httplib2 = ">=0.22.0.2" +typing-extensions = ">=3.10.0" + [[package]] name = "google-auth" version = "2.26.2" @@ -6244,6 +6260,17 @@ files = [ {file = "types_html5lib-1.1.11.20240106-py3-none-any.whl", hash = "sha256:61993cb89220107481e0f1da65c388ff8cf3d8c5f6e8483c97559639a596b697"}, ] +[[package]] +name = "types-httplib2" +version = "0.22.0.20240310" +description = "Typing stubs for httplib2" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-httplib2-0.22.0.20240310.tar.gz", hash = "sha256:1eda99fea18ec8a1dc1a725ead35b889d0836fec1b11ae6f1fe05440724c1d15"}, + {file = "types_httplib2-0.22.0.20240310-py3-none-any.whl", hash = "sha256:8cd706fc81f0da32789a4373a28df6f39e9d5657d1281db4d2fd22ee29e83661"}, +] + [[package]] name = "types-markdown" version = "3.5.0.20240106" @@ -6972,4 +6999,4 @@ benchmark = ["agbenchmark"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "a6022fadc0c861e7947bf4367a76242c83adcd0f963a742e2dd70a0196807bf6" +content-hash = "e76c41c9927cbe6cec2aa7f7cf3263dc136e69cec6c3531630f2cbd75613f9d6" diff --git a/autogpt/pyproject.toml b/autogpt/pyproject.toml index 4a51d708ac..ec91f0babd 100644 --- a/autogpt/pyproject.toml +++ b/autogpt/pyproject.toml @@ -69,6 +69,7 @@ benchmark = ["agbenchmark"] black = "^23.12.1" flake8 = "^7.0.0" gitpython = "^3.1.32" +google-api-python-client-stubs = "^1.26.0" isort = "^5.13.1" pre-commit = "*" pyright = "^1.1.364" diff --git a/autogpt/tests/unit/test_web_search.py b/autogpt/tests/unit/test_web_search.py index 411999c00c..fd6b9d4126 100644 --- a/autogpt/tests/unit/test_web_search.py +++ b/autogpt/tests/unit/test_web_search.py @@ -1,12 +1,19 @@ import json +from typing import TYPE_CHECKING +from unittest.mock import Mock import pytest -from forge.components.web.search import WebSearchComponent +from forge.components.web.search import SearchResult, WebSearchComponent +from forge.config import Config from forge.utils.exceptions import ConfigurationError from googleapiclient.errors import HttpError +from pytest_mock import MockerFixture from autogpt.agents.agent import Agent +if TYPE_CHECKING: + from googleapiclient._apis.customsearch.v1 import Search as GoogleSearch + @pytest.fixture def web_search_component(agent: Agent): @@ -14,56 +21,36 @@ def web_search_component(agent: Agent): @pytest.mark.parametrize( - "query, expected_output", - [("test", "test"), (["test1", "test2"], '["test1", "test2"]')], -) -@pytest.fixture -def test_safe_google_results( - query, expected_output, web_search_component: WebSearchComponent -): - result = web_search_component.safe_google_results(query) - assert isinstance(result, str) - assert result == expected_output - - -@pytest.fixture -def test_safe_google_results_invalid_input(web_search_component: WebSearchComponent): - with pytest.raises(AttributeError): - web_search_component.safe_google_results(123) # type: ignore - - -@pytest.mark.parametrize( - "query, num_results, expected_output_parts, return_value", + "query, num_results, ddg_return_value, expected_output", [ ( "test", 1, - ("Result 1", "https://example.com/result1"), [{"title": "Result 1", "href": "https://example.com/result1"}], + [SearchResult(title="Result 1", url="https://example.com/result1")], ), - ("", 1, (), []), ("no results", 1, (), []), ], ) -def test_google_search( - query, - num_results, - expected_output_parts, - return_value, - mocker, +def test_ddg_search( + query: str, + num_results: int, + ddg_return_value: list[dict], + expected_output: list[SearchResult], + mocker: MockerFixture, web_search_component: WebSearchComponent, ): mock_ddg = mocker.Mock() - mock_ddg.return_value = return_value + mock_ddg.return_value = ddg_return_value mocker.patch("forge.components.web.search.DDGS.text", mock_ddg) actual_output = web_search_component.web_search(query, num_results=num_results) - for o in expected_output_parts: + for o in expected_output: assert o in actual_output @pytest.fixture -def mock_googleapiclient(mocker): +def mock_googleapiclient(mocker: MockerFixture): mock_build = mocker.patch("googleapiclient.discovery.build") mock_service = mocker.Mock() mock_build.return_value = mock_service @@ -71,36 +58,39 @@ def mock_googleapiclient(mocker): @pytest.mark.parametrize( - "query, num_results, search_results, expected_output", + "query, num_results, google_return_value, expected_output", [ ( "test", 3, [ - {"link": "http://example.com/result1"}, - {"link": "http://example.com/result2"}, - {"link": "http://example.com/result3"}, + {"title": "Result 1", "link": "http://example.com/result1"}, + {"title": "Result 2", "link": "http://example.com/result2"}, + {"title": "Result 3", "link": "http://example.com/result3"}, ], [ - "http://example.com/result1", - "http://example.com/result2", - "http://example.com/result3", + SearchResult(title="Result 1", url="http://example.com/result1"), + SearchResult(title="Result 2", url="http://example.com/result2"), + SearchResult(title="Result 3", url="http://example.com/result3"), ], ), - ("", 3, [], []), ], ) -def test_google_official_search( - query, - num_results, - expected_output, - search_results, - mock_googleapiclient, +def test_google_custom_search( + query: str, + num_results: int, + google_return_value: "GoogleSearch", + expected_output: list[SearchResult], + config: Config, + mock_googleapiclient: Mock, web_search_component: WebSearchComponent, ): - mock_googleapiclient.return_value = search_results + config.google_api_key = "mock_api_key" + config.google_custom_search_engine_id = "mock_search_engine_id" + + mock_googleapiclient.return_value = google_return_value actual_output = web_search_component.google(query, num_results=num_results) - assert actual_output == web_search_component.safe_google_results(expected_output) + assert actual_output == expected_output @pytest.mark.parametrize( @@ -122,15 +112,19 @@ def test_google_official_search( ), ], ) -def test_google_official_search_errors( - query, - num_results, - expected_error_type, - mock_googleapiclient, - http_code, - error_msg, +def test_google_custom_search_errors( + query: str, + num_results: int, + expected_error_type: type[Exception], + http_code: int, + error_msg: str, + config: Config, + mock_googleapiclient: Mock, web_search_component: WebSearchComponent, ): + config.google_api_key = "mock_api_key" + config.google_custom_search_engine_id = "mock_search_engine_id" + class resp: def __init__(self, _status, _reason): self.status = _status @@ -140,7 +134,7 @@ def test_google_official_search_errors( "error": {"code": http_code, "message": error_msg, "reason": "backendError"} } error = HttpError( - resp=resp(http_code, error_msg), + resp=resp(http_code, error_msg), # type: ignore content=str.encode(json.dumps(response_content)), uri="https://www.googleapis.com/customsearch/v1?q=invalid+query&cx", ) diff --git a/forge/forge/components/web/search.py b/forge/forge/components/web/search.py index 36d304239e..f89f42c931 100644 --- a/forge/forge/components/web/search.py +++ b/forge/forge/components/web/search.py @@ -4,18 +4,25 @@ import time from typing import Iterator from duckduckgo_search import DDGS +from pydantic import BaseModel from forge.agent.protocols import CommandProvider, DirectiveProvider from forge.command import Command, command from forge.config.config import Config from forge.models.json_schema import JSONSchema -from forge.utils.exceptions import ConfigurationError +from forge.utils.exceptions import ConfigurationError, InvalidArgumentError DUCKDUCKGO_MAX_ATTEMPTS = 3 logger = logging.getLogger(__name__) +class SearchResult(BaseModel): + title: str + url: str + excerpt: str = "" + + class WebSearchComponent(DirectiveProvider, CommandProvider): """Provides commands to search the web.""" @@ -61,7 +68,7 @@ class WebSearchComponent(DirectiveProvider, CommandProvider): ), }, ) - def web_search(self, query: str, num_results: int = 8) -> str: + def web_search(self, query: str, num_results: int = 8) -> list[SearchResult]: """Return the results of a Google search Args: @@ -71,13 +78,13 @@ class WebSearchComponent(DirectiveProvider, CommandProvider): Returns: str: The results of the search. """ + if not query: + raise InvalidArgumentError("'query' must be non-empty") + search_results = [] attempts = 0 while attempts < DUCKDUCKGO_MAX_ATTEMPTS: - if not query: - return json.dumps(search_results) - search_results = DDGS().text(query, max_results=num_results) if search_results: @@ -86,23 +93,15 @@ class WebSearchComponent(DirectiveProvider, CommandProvider): time.sleep(1) attempts += 1 - search_results = [ - { - "title": r["title"], - "url": r["href"], - **({"exerpt": r["body"]} if r.get("body") else {}), - } + return [ + SearchResult( + title=r["title"], + excerpt=r.get("body", ""), + url=r["href"], + ) for r in search_results ] - results = ("## Search results\n") + "\n\n".join( - f"### \"{r['title']}\"\n" - f"**URL:** {r['url']} \n" - "**Excerpt:** " + (f'"{exerpt}"' if (exerpt := r.get("exerpt")) else "N/A") - for r in search_results - ) - return self.safe_google_results(results) - @command( ["google"], "Google Search", @@ -121,7 +120,7 @@ class WebSearchComponent(DirectiveProvider, CommandProvider): ), }, ) - def google(self, query: str, num_results: int = 8) -> str | list[str]: + def google(self, query: str, num_results: int = 8) -> list[SearchResult]: """Return the results of a Google search using the official Google API Args: @@ -139,6 +138,7 @@ class WebSearchComponent(DirectiveProvider, CommandProvider): # Get the Google API key and Custom Search Engine ID from the config file api_key = self.legacy_config.google_api_key custom_search_engine_id = self.legacy_config.google_custom_search_engine_id + assert api_key and custom_search_engine_id # checked in get_commands() # Initialize the Custom Search API service service = build("customsearch", "v1", developerKey=api_key) @@ -151,44 +151,25 @@ class WebSearchComponent(DirectiveProvider, CommandProvider): ) # Extract the search result items from the response - search_results = result.get("items", []) - - # Create a list of only the URLs from the search results - search_results_links = [item["link"] for item in search_results] + return [ + SearchResult( + title=r.get("title", ""), + excerpt=r.get("snippet", ""), + url=r["link"], + ) + for r in result.get("items", []) + if "link" in r + ] except HttpError as e: # Handle errors in the API call error_details = json.loads(e.content.decode()) # Check if the error is related to an invalid or missing API key - if error_details.get("error", {}).get( - "code" - ) == 403 and "invalid API key" in error_details.get("error", {}).get( - "message", "" + if error_details.get("error", {}).get("code") == 403 and ( + "invalid API key" in error_details["error"].get("message", "") ): raise ConfigurationError( "The provided Google API key is invalid or missing." ) raise - # google_result can be a list or a string depending on the search results - - # Return the list of search result URLs - return self.safe_google_results(search_results_links) - - def safe_google_results(self, results: str | list) -> str: - """ - Return the results of a Google search in a safe format. - - Args: - results (str | list): The search results. - - Returns: - str: The results of the search. - """ - if isinstance(results, list): - safe_message = json.dumps( - [result.encode("utf-8", "ignore").decode("utf-8") for result in results] - ) - else: - safe_message = results.encode("utf-8", "ignore").decode("utf-8") - return safe_message diff --git a/forge/poetry.lock b/forge/poetry.lock index d713a148cf..887ad5211b 100644 --- a/forge/poetry.lock +++ b/forge/poetry.lock @@ -1969,13 +1969,13 @@ grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] [[package]] name = "google-api-python-client" -version = "2.129.0" +version = "2.132.0" description = "Google API Client Library for Python" optional = false python-versions = ">=3.7" files = [ - {file = "google-api-python-client-2.129.0.tar.gz", hash = "sha256:984cc8cc8eb4923468b1926d2b8effc5b459a4dda3c845896eb87c153b28ef84"}, - {file = "google_api_python_client-2.129.0-py2.py3-none-any.whl", hash = "sha256:d50f7e2dfdbb7fc2732f6a0cba1c54d7bb676390679526c6bb628c901e43ec86"}, + {file = "google-api-python-client-2.132.0.tar.gz", hash = "sha256:d6340dc83b72d72333cee5d50f7dcfecbff66a8783164090e945f985ec4c374d"}, + {file = "google_api_python_client-2.132.0-py2.py3-none-any.whl", hash = "sha256:cde87700bd4d37f39f5e940292c1c6cd0910990b5b01f50b1332a8cea38e8595"}, ] [package.dependencies] @@ -1985,6 +1985,22 @@ google-auth-httplib2 = ">=0.2.0,<1.0.0" httplib2 = ">=0.19.0,<1.dev0" uritemplate = ">=3.0.1,<5" +[[package]] +name = "google-api-python-client-stubs" +version = "1.26.0" +description = "Type stubs for google-api-python-client" +optional = false +python-versions = "<4.0,>=3.7" +files = [ + {file = "google_api_python_client_stubs-1.26.0-py3-none-any.whl", hash = "sha256:0614b0cef5beac43e6ab02418f07e64ee66dc99ae4e377d54a155ac261533987"}, + {file = "google_api_python_client_stubs-1.26.0.tar.gz", hash = "sha256:f3b38b46f7b5cf4f6e7cc63ca554a2d23096d49c841f38b9ea553a5237074b56"}, +] + +[package.dependencies] +google-api-python-client = ">=2.130.0" +types-httplib2 = ">=0.22.0.2" +typing-extensions = ">=3.10.0" + [[package]] name = "google-auth" version = "2.26.2" @@ -6147,6 +6163,17 @@ files = [ {file = "types_awscrt-0.20.9.tar.gz", hash = "sha256:64898a2f4a2468f66233cb8c29c5f66de907cf80ba1ef5bb1359aef2f81bb521"}, ] +[[package]] +name = "types-httplib2" +version = "0.22.0.20240310" +description = "Typing stubs for httplib2" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-httplib2-0.22.0.20240310.tar.gz", hash = "sha256:1eda99fea18ec8a1dc1a725ead35b889d0836fec1b11ae6f1fe05440724c1d15"}, + {file = "types_httplib2-0.22.0.20240310-py3-none-any.whl", hash = "sha256:8cd706fc81f0da32789a4373a28df6f39e9d5657d1281db4d2fd22ee29e83661"}, +] + [[package]] name = "types-requests" version = "2.31.0.6" @@ -6830,4 +6857,4 @@ benchmark = ["agbenchmark"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "e04a03e3f2663d3e54d5f4e6649cd4442cb00fd17ae5c68d06dbaadfc02ca309" +content-hash = "120a407f04865b3d781b5ad3523b4e1bb5c8ca10664055f9fbdcec5b31dd075c" diff --git a/forge/pyproject.toml b/forge/pyproject.toml index e99a2300dd..0eebb07e73 100644 --- a/forge/pyproject.toml +++ b/forge/pyproject.toml @@ -65,6 +65,7 @@ isort = "^5.13.1" pyright = "^1.1.364" pre-commit = "^3.3.3" boto3-stubs = { extras = ["s3"], version = "^1.33.6" } +google-api-python-client-stubs = "^1.26.0" types-requests = "^2.31.0.2" pytest = "^7.4.0" pytest-asyncio = "^0.21.1"