make web_search and google commands code-friendly

This commit is contained in:
Reinier van der Leer
2024-06-09 01:16:58 +02:00
parent 81bac301e8
commit 4ffc57fdaf
6 changed files with 148 additions and 117 deletions

View File

@@ -4,18 +4,25 @@ import time
from typing import Iterator
from duckduckgo_search import DDGS
from pydantic import BaseModel
from forge.agent.protocols import CommandProvider, DirectiveProvider
from forge.command import Command, command
from forge.config.config import Config
from forge.models.json_schema import JSONSchema
from forge.utils.exceptions import ConfigurationError
from forge.utils.exceptions import ConfigurationError, InvalidArgumentError
DUCKDUCKGO_MAX_ATTEMPTS = 3
logger = logging.getLogger(__name__)
class SearchResult(BaseModel):
title: str
url: str
excerpt: str = ""
class WebSearchComponent(DirectiveProvider, CommandProvider):
"""Provides commands to search the web."""
@@ -61,7 +68,7 @@ class WebSearchComponent(DirectiveProvider, CommandProvider):
),
},
)
def web_search(self, query: str, num_results: int = 8) -> str:
def web_search(self, query: str, num_results: int = 8) -> list[SearchResult]:
"""Return the results of a Google search
Args:
@@ -71,13 +78,13 @@ class WebSearchComponent(DirectiveProvider, CommandProvider):
Returns:
str: The results of the search.
"""
if not query:
raise InvalidArgumentError("'query' must be non-empty")
search_results = []
attempts = 0
while attempts < DUCKDUCKGO_MAX_ATTEMPTS:
if not query:
return json.dumps(search_results)
search_results = DDGS().text(query, max_results=num_results)
if search_results:
@@ -86,23 +93,15 @@ class WebSearchComponent(DirectiveProvider, CommandProvider):
time.sleep(1)
attempts += 1
search_results = [
{
"title": r["title"],
"url": r["href"],
**({"exerpt": r["body"]} if r.get("body") else {}),
}
return [
SearchResult(
title=r["title"],
excerpt=r.get("body", ""),
url=r["href"],
)
for r in search_results
]
results = ("## Search results\n") + "\n\n".join(
f"### \"{r['title']}\"\n"
f"**URL:** {r['url']} \n"
"**Excerpt:** " + (f'"{exerpt}"' if (exerpt := r.get("exerpt")) else "N/A")
for r in search_results
)
return self.safe_google_results(results)
@command(
["google"],
"Google Search",
@@ -121,7 +120,7 @@ class WebSearchComponent(DirectiveProvider, CommandProvider):
),
},
)
def google(self, query: str, num_results: int = 8) -> str | list[str]:
def google(self, query: str, num_results: int = 8) -> list[SearchResult]:
"""Return the results of a Google search using the official Google API
Args:
@@ -139,6 +138,7 @@ class WebSearchComponent(DirectiveProvider, CommandProvider):
# Get the Google API key and Custom Search Engine ID from the config file
api_key = self.legacy_config.google_api_key
custom_search_engine_id = self.legacy_config.google_custom_search_engine_id
assert api_key and custom_search_engine_id # checked in get_commands()
# Initialize the Custom Search API service
service = build("customsearch", "v1", developerKey=api_key)
@@ -151,44 +151,25 @@ class WebSearchComponent(DirectiveProvider, CommandProvider):
)
# Extract the search result items from the response
search_results = result.get("items", [])
# Create a list of only the URLs from the search results
search_results_links = [item["link"] for item in search_results]
return [
SearchResult(
title=r.get("title", ""),
excerpt=r.get("snippet", ""),
url=r["link"],
)
for r in result.get("items", [])
if "link" in r
]
except HttpError as e:
# Handle errors in the API call
error_details = json.loads(e.content.decode())
# Check if the error is related to an invalid or missing API key
if error_details.get("error", {}).get(
"code"
) == 403 and "invalid API key" in error_details.get("error", {}).get(
"message", ""
if error_details.get("error", {}).get("code") == 403 and (
"invalid API key" in error_details["error"].get("message", "")
):
raise ConfigurationError(
"The provided Google API key is invalid or missing."
)
raise
# google_result can be a list or a string depending on the search results
# Return the list of search result URLs
return self.safe_google_results(search_results_links)
def safe_google_results(self, results: str | list) -> str:
"""
Return the results of a Google search in a safe format.
Args:
results (str | list): The search results.
Returns:
str: The results of the search.
"""
if isinstance(results, list):
safe_message = json.dumps(
[result.encode("utf-8", "ignore").decode("utf-8") for result in results]
)
else:
safe_message = results.encode("utf-8", "ignore").decode("utf-8")
return safe_message

35
forge/poetry.lock generated
View File

@@ -1969,13 +1969,13 @@ grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"]
[[package]]
name = "google-api-python-client"
version = "2.129.0"
version = "2.132.0"
description = "Google API Client Library for Python"
optional = false
python-versions = ">=3.7"
files = [
{file = "google-api-python-client-2.129.0.tar.gz", hash = "sha256:984cc8cc8eb4923468b1926d2b8effc5b459a4dda3c845896eb87c153b28ef84"},
{file = "google_api_python_client-2.129.0-py2.py3-none-any.whl", hash = "sha256:d50f7e2dfdbb7fc2732f6a0cba1c54d7bb676390679526c6bb628c901e43ec86"},
{file = "google-api-python-client-2.132.0.tar.gz", hash = "sha256:d6340dc83b72d72333cee5d50f7dcfecbff66a8783164090e945f985ec4c374d"},
{file = "google_api_python_client-2.132.0-py2.py3-none-any.whl", hash = "sha256:cde87700bd4d37f39f5e940292c1c6cd0910990b5b01f50b1332a8cea38e8595"},
]
[package.dependencies]
@@ -1985,6 +1985,22 @@ google-auth-httplib2 = ">=0.2.0,<1.0.0"
httplib2 = ">=0.19.0,<1.dev0"
uritemplate = ">=3.0.1,<5"
[[package]]
name = "google-api-python-client-stubs"
version = "1.26.0"
description = "Type stubs for google-api-python-client"
optional = false
python-versions = "<4.0,>=3.7"
files = [
{file = "google_api_python_client_stubs-1.26.0-py3-none-any.whl", hash = "sha256:0614b0cef5beac43e6ab02418f07e64ee66dc99ae4e377d54a155ac261533987"},
{file = "google_api_python_client_stubs-1.26.0.tar.gz", hash = "sha256:f3b38b46f7b5cf4f6e7cc63ca554a2d23096d49c841f38b9ea553a5237074b56"},
]
[package.dependencies]
google-api-python-client = ">=2.130.0"
types-httplib2 = ">=0.22.0.2"
typing-extensions = ">=3.10.0"
[[package]]
name = "google-auth"
version = "2.26.2"
@@ -6147,6 +6163,17 @@ files = [
{file = "types_awscrt-0.20.9.tar.gz", hash = "sha256:64898a2f4a2468f66233cb8c29c5f66de907cf80ba1ef5bb1359aef2f81bb521"},
]
[[package]]
name = "types-httplib2"
version = "0.22.0.20240310"
description = "Typing stubs for httplib2"
optional = false
python-versions = ">=3.8"
files = [
{file = "types-httplib2-0.22.0.20240310.tar.gz", hash = "sha256:1eda99fea18ec8a1dc1a725ead35b889d0836fec1b11ae6f1fe05440724c1d15"},
{file = "types_httplib2-0.22.0.20240310-py3-none-any.whl", hash = "sha256:8cd706fc81f0da32789a4373a28df6f39e9d5657d1281db4d2fd22ee29e83661"},
]
[[package]]
name = "types-requests"
version = "2.31.0.6"
@@ -6830,4 +6857,4 @@ benchmark = ["agbenchmark"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "e04a03e3f2663d3e54d5f4e6649cd4442cb00fd17ae5c68d06dbaadfc02ca309"
content-hash = "120a407f04865b3d781b5ad3523b4e1bb5c8ca10664055f9fbdcec5b31dd075c"

View File

@@ -65,6 +65,7 @@ isort = "^5.13.1"
pyright = "^1.1.364"
pre-commit = "^3.3.3"
boto3-stubs = { extras = ["s3"], version = "^1.33.6" }
google-api-python-client-stubs = "^1.26.0"
types-requests = "^2.31.0.2"
pytest = "^7.4.0"
pytest-asyncio = "^0.21.1"