mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-10 07:58:12 -05:00
tests(sdk): e2e tests for llamaindex sdk (#207)
This commit is contained in:
@@ -32,6 +32,8 @@ steps:
|
||||
name: 'python:${_VERSION}'
|
||||
env:
|
||||
- TOOLBOX_URL=$_TOOLBOX_URL
|
||||
- TOOLBOX_VERSION=$_TOOLBOX_VERSION
|
||||
- GOOGLE_CLOUD_PROJECT=$PROJECT_ID
|
||||
args:
|
||||
- '-c'
|
||||
- >-
|
||||
@@ -40,4 +42,5 @@ steps:
|
||||
options:
|
||||
logging: CLOUD_LOGGING_ONLY
|
||||
substitutions:
|
||||
_VERSION: '3.13'
|
||||
_VERSION: '3.13'
|
||||
_TOOLBOX_VERSION: '0.0.4'
|
||||
@@ -39,7 +39,9 @@ test = [
|
||||
"pytest-asyncio==0.24.0",
|
||||
"pytest==8.3.3",
|
||||
"pytest-cov==6.0.0",
|
||||
"Pillow==10.4.0"
|
||||
"Pillow==10.4.0",
|
||||
"google-cloud-secret-manager==2.22.0",
|
||||
"google-cloud-storage==2.19.0",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
|
||||
166
sdks/llamaindex/tests/conftest.py
Normal file
166
sdks/llamaindex/tests/conftest.py
Normal file
@@ -0,0 +1,166 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Contains pytest fixtures that are accessible from all
|
||||
files present in the same directory."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import platform
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
from typing import Generator
|
||||
|
||||
import google
|
||||
import pytest_asyncio
|
||||
from google.auth import compute_engine
|
||||
from google.cloud import secretmanager, storage
|
||||
|
||||
|
||||
#### Define Utility Functions
|
||||
def get_env_var(key: str) -> str:
|
||||
"""Gets environment variables."""
|
||||
value = os.environ.get(key)
|
||||
if value is None:
|
||||
raise ValueError(f"Must set env var {key}")
|
||||
return value
|
||||
|
||||
|
||||
def access_secret_version(
|
||||
project_id: str, secret_id: str, version_id: str = "latest"
|
||||
) -> str:
|
||||
"""Accesses the payload of a given secret version from Secret Manager."""
|
||||
client = secretmanager.SecretManagerServiceClient()
|
||||
name = f"projects/{project_id}/secrets/{secret_id}/versions/{version_id}"
|
||||
response = client.access_secret_version(request={"name": name})
|
||||
return response.payload.data.decode("UTF-8")
|
||||
|
||||
|
||||
def create_tmpfile(content: str) -> str:
|
||||
"""Creates a temporary file with the given content."""
|
||||
with tempfile.NamedTemporaryFile(delete=False, mode="w") as tmpfile:
|
||||
tmpfile.write(content)
|
||||
return tmpfile.name
|
||||
|
||||
|
||||
def download_blob(
|
||||
bucket_name: str, source_blob_name: str, destination_file_name: str
|
||||
) -> None:
|
||||
"""Downloads a blob from a GCS bucket."""
|
||||
storage_client = storage.Client()
|
||||
|
||||
bucket = storage_client.bucket(bucket_name)
|
||||
blob = bucket.blob(source_blob_name)
|
||||
blob.download_to_filename(destination_file_name)
|
||||
|
||||
print(f"Blob {source_blob_name} downloaded to {destination_file_name}.")
|
||||
|
||||
|
||||
def get_toolbox_binary_url(toolbox_version: str) -> str:
|
||||
"""Constructs the GCS path to the toolbox binary."""
|
||||
os_system = platform.system().lower()
|
||||
arch = (
|
||||
"arm64" if os_system == "darwin" and platform.machine() == "arm64" else "amd64"
|
||||
)
|
||||
return f"v{toolbox_version}/{os_system}/{arch}/toolbox"
|
||||
|
||||
|
||||
def get_auth_token(client_id: str) -> str:
|
||||
"""Retrieves an authentication token"""
|
||||
request = google.auth.transport.requests.Request()
|
||||
credentials = compute_engine.IDTokenCredentials(
|
||||
request=request,
|
||||
target_audience=client_id,
|
||||
use_metadata_identity_endpoint=True,
|
||||
)
|
||||
if not credentials.valid:
|
||||
credentials.refresh(request)
|
||||
return credentials.token
|
||||
|
||||
|
||||
#### Define Fixtures
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
def project_id() -> str:
|
||||
return get_env_var("GOOGLE_CLOUD_PROJECT")
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
def toolbox_version() -> str:
|
||||
return get_env_var("TOOLBOX_VERSION")
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
def tools_file_path(project_id: str) -> Generator[str]:
|
||||
"""Provides a temporary file path containing the tools manifest."""
|
||||
tools_manifest = access_secret_version(
|
||||
project_id=project_id, secret_id="sdk_testing_tools"
|
||||
)
|
||||
tools_file_path = create_tmpfile(tools_manifest)
|
||||
yield tools_file_path
|
||||
os.remove(tools_file_path)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
def auth_token1(project_id: str) -> str:
|
||||
client_id = access_secret_version(
|
||||
project_id=project_id, secret_id="sdk_testing_client1"
|
||||
)
|
||||
return get_auth_token(client_id)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
def auth_token2(project_id: str) -> str:
|
||||
client_id = access_secret_version(
|
||||
project_id=project_id, secret_id="sdk_testing_client2"
|
||||
)
|
||||
return get_auth_token(client_id)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
def toolbox_server(toolbox_version: str, tools_file_path: str) -> Generator[None]:
|
||||
"""Starts the toolbox server as a subprocess."""
|
||||
print("Downloading toolbox binary from gcs bucket...")
|
||||
source_blob_name = get_toolbox_binary_url(toolbox_version)
|
||||
download_blob("genai-toolbox", source_blob_name, "toolbox")
|
||||
print("Toolbox binary downloaded successfully.")
|
||||
try:
|
||||
print("Opening toolbox server process...")
|
||||
# Make toolbox executable
|
||||
os.chmod("toolbox", 0o700)
|
||||
# Run toolbox binary
|
||||
toolbox_server = subprocess.Popen(
|
||||
["./toolbox", "--tools_file", tools_file_path]
|
||||
)
|
||||
|
||||
# Wait for server to start
|
||||
# Retry logic with a timeout
|
||||
for _ in range(5): # retries
|
||||
time.sleep(4)
|
||||
print("Checking if toolbox is successfully started...")
|
||||
if toolbox_server.poll() is None:
|
||||
print("Toolbox server started successfully.")
|
||||
break
|
||||
else:
|
||||
raise RuntimeError("Toolbox server failed to start after 5 retries.")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(e.stderr.decode("utf-8"))
|
||||
print(e.stdout.decode("utf-8"))
|
||||
raise RuntimeError(f"{e}\n\n{e.stderr.decode('utf-8')}") from e
|
||||
yield
|
||||
|
||||
# Clean up toolbox server
|
||||
toolbox_server.terminate()
|
||||
toolbox_server.wait()
|
||||
155
sdks/llamaindex/tests/test_e2e.py
Normal file
155
sdks/llamaindex/tests/test_e2e.py
Normal file
@@ -0,0 +1,155 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""End-to-end tests for the toolbox SDK interacting with the toolbox server.
|
||||
|
||||
This file covers the following use cases:
|
||||
|
||||
1. Loading a tool.
|
||||
2. Loading a specific toolset.
|
||||
3. Loading the default toolset (contains all tools).
|
||||
4. Running a tool with no required auth, with auth provided.
|
||||
5. Running a tool with required auth:
|
||||
a. No auth provided.
|
||||
b. Wrong auth provided: The tool requires a different authentication
|
||||
than the one provided.
|
||||
c. Correct auth provided.
|
||||
6. Running a tool with a parameter that requires auth:
|
||||
a. No auth provided.
|
||||
b. Correct auth provided.
|
||||
c. Auth provided does not contain the required claim.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from aiohttp import ClientResponseError
|
||||
|
||||
from toolbox_llamaindex_sdk.client import ToolboxClient
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.usefixtures("toolbox_server")
|
||||
class TestE2EClient:
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def toolbox(self):
|
||||
"""Provides a ToolboxClient instance for each test."""
|
||||
toolbox = ToolboxClient("http://localhost:5000")
|
||||
yield toolbox
|
||||
await toolbox.close()
|
||||
|
||||
#### Basic e2e tests
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_tool(self, toolbox):
|
||||
tool = await toolbox.load_tool("get-n-rows")
|
||||
response = await tool.acall(num_rows="2")
|
||||
result = response.raw_output["result"]
|
||||
|
||||
assert "row1" in result
|
||||
assert "row2" in result
|
||||
assert "row3" not in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_toolset_specific(self, toolbox):
|
||||
toolset = await toolbox.load_toolset("my-toolset")
|
||||
assert len(toolset) == 1
|
||||
assert toolset[0].metadata.name == "get-row-by-id"
|
||||
|
||||
toolset = await toolbox.load_toolset("my-toolset-2")
|
||||
assert len(toolset) == 2
|
||||
tool_names = ["get-n-rows", "get-row-by-id"]
|
||||
assert toolset[0].metadata.name in tool_names
|
||||
assert toolset[1].metadata.name in tool_names
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_toolset_all(self, toolbox):
|
||||
toolset = await toolbox.load_toolset()
|
||||
assert len(toolset) == 5
|
||||
tool_names = [
|
||||
"get-n-rows",
|
||||
"get-row-by-id",
|
||||
"get-row-by-id-auth",
|
||||
"get-row-by-email-auth",
|
||||
"get-row-by-content-auth",
|
||||
]
|
||||
for tool in toolset:
|
||||
assert tool.metadata.name in tool_names
|
||||
|
||||
##### Auth tests
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skip(reason="b/389574566")
|
||||
async def test_run_tool_unauth_with_auth(self, toolbox, auth_token2):
|
||||
"""Tests running a tool that doesn't require auth, with auth provided."""
|
||||
tool = await toolbox.load_tool(
|
||||
"get-row-by-id", auth_tokens={"my-test-auth": lambda: auth_token2}
|
||||
)
|
||||
response = await tool.acall(id="2")
|
||||
assert "row2" in response.raw_output["result"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_tool_no_auth(self, toolbox):
|
||||
"""Tests running a tool requiring auth without providing auth."""
|
||||
tool = await toolbox.load_tool(
|
||||
"get-row-by-id-auth",
|
||||
)
|
||||
with pytest.raises(ClientResponseError, match="401, message='Unauthorized'"):
|
||||
await tool.acall(id="2")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_tool_wrong_auth(self, toolbox, auth_token2):
|
||||
"""Tests running a tool with incorrect auth."""
|
||||
toolbox.add_auth_token("my-test-auth", lambda: auth_token2)
|
||||
tool = await toolbox.load_tool(
|
||||
"get-row-by-id-auth",
|
||||
)
|
||||
# TODO: Fix error message (b/389577313)
|
||||
with pytest.raises(ClientResponseError, match="400, message='Bad Request'"):
|
||||
await tool.acall(id="2")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_tool_auth(self, toolbox, auth_token1):
|
||||
"""Tests running a tool with correct auth."""
|
||||
toolbox.add_auth_token("my-test-auth", lambda: auth_token1)
|
||||
tool = await toolbox.load_tool(
|
||||
"get-row-by-id-auth",
|
||||
)
|
||||
response = await tool.acall(id="2")
|
||||
assert "row2" in response.raw_output["result"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_tool_param_auth_no_auth(self, toolbox):
|
||||
"""Tests running a tool with a param requiring auth, without auth."""
|
||||
tool = await toolbox.load_tool("get-row-by-email-auth")
|
||||
with pytest.raises(PermissionError, match="Login required"):
|
||||
await tool.acall()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_tool_param_auth(self, toolbox, auth_token1):
|
||||
"""Tests running a tool with a param requiring auth, with correct auth."""
|
||||
tool = await toolbox.load_tool(
|
||||
"get-row-by-email-auth", auth_tokens={"my-test-auth": lambda: auth_token1}
|
||||
)
|
||||
response = await tool.acall()
|
||||
result = response.raw_output["result"]
|
||||
assert "row4" in result
|
||||
assert "row5" in result
|
||||
assert "row6" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_tool_param_auth_no_field(self, toolbox, auth_token1):
|
||||
"""Tests running a tool with a param requiring auth, with insufficient auth."""
|
||||
tool = await toolbox.load_tool(
|
||||
"get-row-by-content-auth", auth_tokens={"my-test-auth": lambda: auth_token1}
|
||||
)
|
||||
with pytest.raises(ClientResponseError, match="400, message='Bad Request'"):
|
||||
await tool.acall()
|
||||
Reference in New Issue
Block a user