mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
A simple driver.refresh() wouldn't refresh the script cache used by web workers. The page was therefore not fully loaded, leaving all test and benchmark buttons disabled, which triggered timeouts while browsing with Selenium. Console log printing has also been cleaned up for the Chrome browser.
523 lines
15 KiB
Python
523 lines
15 KiB
Python
"""
webdriver
---------

Script to run the tests and benchmarks of the client-side tfhe-rs WASM code
in a web browser driven by Selenium.
"""
|
|
|
|
import argparse
import dataclasses
import datetime
import enum
import json
import os
import pathlib
import shlex
import signal
import socket
import subprocess
import sys
import threading
import time

from bs4 import BeautifulSoup
from selenium import webdriver
from selenium.common.exceptions import TimeoutException
from selenium.webdriver import Keys
from selenium.webdriver.chrome.options import Options as ChromeOptions
from selenium.webdriver.chrome.service import Service as ChromeService
from selenium.webdriver.common.by import By
from selenium.webdriver.firefox.options import Options as FirefoxOptions
from selenium.webdriver.firefox.service import Service as FirefoxService
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.ui import WebDriverWait
|
|
|
|
# Command-line interface of the script.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-a",
    "--address",
    dest="address",
    default="localhost",
    help="Address to testing Node server",
)
parser.add_argument(
    "-p",
    "--port",
    dest="port",
    default=3000,
    type=int,
    help="Port to testing Node server",
)
parser.add_argument(
    "-k",
    "--browser-kind",
    dest="browser_kind",
    choices=["chrome", "firefox"],
    required=True,
    # Must match a member name of BrowserKind.
    help="Kind of web browser to drive",
)
parser.add_argument(
    "-b",
    "--browser-path",
    dest="browser_path",
    required=True,
    help="Path to browser file",
)
parser.add_argument(
    "-d",
    "--driver-path",
    dest="driver_path",
    required=True,
    help="Path to web driver file",
)
parser.add_argument(
    "--index-path",
    dest="index_path",
    default="tfhe/web_wasm_parallel_tests/index.html",
    help="Path to HTML index file containing all the tests/benchmarks",
)
# NOTE: main() applies --id-pattern in priority; --value-pattern is only used
# when no ID pattern is given (the two filters are not combined).
parser.add_argument(
    "--id-pattern",
    dest="id_filter_pattern",
    help="Pattern to use to filter HTML button ID displayed on web page",
)
parser.add_argument(
    "--value-pattern",
    dest="value_filter_pattern",
    help="Pattern to use to filter HTML button value displayed on web page",
)
parser.add_argument(
    "-f",
    "--fail-fast",
    dest="fail_fast",
    action="store_true",
    help="Exit on first failed test",
)
parser.add_argument(
    "--server-cmd",
    dest="server_cmd",
    help="Command to execute to launch web server in the background",
)
parser.add_argument(
    "--server-workdir",
    dest="server_workdir",
    help="Path to working directory to launch web server",
)
|
|
|
|
|
|
class BrowserKind(enum.Enum):
    """
    Web browsers this script is able to drive.

    Member names match the ``--browser-kind`` command-line choices.
    """

    chrome = enum.auto()
    firefox = enum.auto()
|
|
|
|
|
|
class Driver:
|
|
"""
|
|
Representation of a web driver relying on Selenium.
|
|
"""
|
|
|
|
def __init__(self, browser_path, driver_path, browser_kind, threaded_logs=False):
|
|
"""
|
|
:param browser_path: path to binary web browser as :class:`str`
|
|
:param driver_path: path to binary web driver as :class:`str`
|
|
:param browser_kind: :class:`BrowserKind`
|
|
:param threaded_logs: launch a thread to display log in parallel
|
|
"""
|
|
self.browser_path = browser_path
|
|
self.driver_path = driver_path
|
|
|
|
self._is_threaded_logs = threaded_logs
|
|
self._log_thread = None
|
|
|
|
self.browser_kind = browser_kind
|
|
|
|
match self.browser_kind:
|
|
case BrowserKind.chrome:
|
|
self.options = ChromeOptions()
|
|
if os.getuid() == 0:
|
|
# If user ID is root then driver needs to run in no-sandbox mode.
|
|
print(
|
|
"Script is running as root, running browser with --no-sandbox for compatibility"
|
|
)
|
|
self.options.add_argument("--no-sandbox")
|
|
case BrowserKind.firefox:
|
|
self.options = FirefoxOptions()
|
|
|
|
self.options.binary_location = self.browser_path
|
|
self.options.add_argument("--headless")
|
|
|
|
self._driver = None
|
|
|
|
self.shutting_down = False
|
|
|
|
def get_driver(self):
|
|
if self._driver is None:
|
|
|
|
match self.browser_kind:
|
|
case BrowserKind.chrome:
|
|
driver_service = ChromeService(self.driver_path)
|
|
self.options.set_capability("goog:loggingPrefs", {"browser": "ALL"})
|
|
self._driver = webdriver.Chrome(
|
|
service=driver_service, options=self.options
|
|
)
|
|
if self._is_threaded_logs:
|
|
self._log_thread = threading.Thread(target=self._threaded_logs)
|
|
case BrowserKind.firefox:
|
|
driver_service = FirefoxService(self.driver_path)
|
|
self.options.log.level = "trace"
|
|
self.options.enable_bidi = True
|
|
self._driver = webdriver.Firefox(
|
|
service=driver_service, options=self.options
|
|
)
|
|
self._driver.script.add_console_message_handler(
|
|
self._on_console_logs
|
|
)
|
|
case _:
|
|
print(
|
|
f"{self.browser_kind.name.capitalize()} browser driver is not supported"
|
|
)
|
|
sys.exit(1)
|
|
|
|
if self._log_thread:
|
|
self._log_thread.start()
|
|
|
|
return self._driver
|
|
|
|
def get_page(self, server_url, timeout_seconds=10):
|
|
dr = self.get_driver()
|
|
dr.get(server_url)
|
|
self.wait_for_page_load(self.get_waiter(timeout_seconds))
|
|
|
|
def get_waiter(self, timeout):
|
|
return WebDriverWait(self.get_driver(), timeout)
|
|
|
|
def wait_for_page_load(self, waiter):
|
|
waiter.until(
|
|
lambda d: d.execute_script("return document.readyState") == "complete"
|
|
)
|
|
|
|
def wait_for_button(self, waiter, element_id):
|
|
return waiter.until(EC.element_to_be_clickable((By.ID, element_id)))
|
|
|
|
def wait_for_selection(self, waiter, element):
|
|
return waiter.until(EC.element_to_be_selected(element))
|
|
|
|
def find_element(self, element_id):
|
|
return self.get_driver().find_element(By.ID, element_id)
|
|
|
|
def _on_console_logs(self, log):
|
|
"""
|
|
Callback used for retrieving console log using BiDi protocol reling on websocket
|
|
"""
|
|
# Filter out useless message
|
|
if "using deprecated parameters" in log.text:
|
|
return
|
|
|
|
print(f"{log.level.upper()}: {log.text}")
|
|
|
|
def print_log(self, log_type):
|
|
logs = self.get_driver().get_log(log_type)
|
|
for log in logs:
|
|
# Filter out useless message
|
|
if "using deprecated parameters" in log["message"]:
|
|
continue
|
|
|
|
# String pattern is `<server url> <line:col> "<log message>"`
|
|
# We only care for <log message> part.
|
|
content = log["message"].split(maxsplit=2)[-1].strip('"')
|
|
print(f"{log['level']}: {content}")
|
|
|
|
def _threaded_logs(self):
|
|
while not self.shutting_down:
|
|
self.print_log("browser")
|
|
time.sleep(0.2)
|
|
|
|
def refresh(self):
|
|
match self.browser_kind:
|
|
case BrowserKind.chrome:
|
|
self.get_driver().refresh()
|
|
case BrowserKind.firefox:
|
|
# Need to force refresh in Firefox to avoid script caching by web workers
|
|
self.get_driver().find_element(By.TAG_NAME, "body").send_keys(
|
|
Keys.CONTROL + Keys.SHIFT + "R"
|
|
)
|
|
|
|
def quit(self):
|
|
self.shutting_down = True
|
|
if self._log_thread:
|
|
self._log_thread.join()
|
|
|
|
if self._driver:
|
|
self.get_driver().quit()
|
|
|
|
|
|
@dataclasses.dataclass
class UseCase:
    """
    A single test or benchmark case, described by an HTML ``<input type="button">``.
    """

    # HTML `id` attribute of the button that triggers the case.
    id: str
    # HTML `value` attribute, i.e. the label shown on the button.
    value: str
    # Maximum time, in seconds, the case is given to complete.
    timeout_seconds: int
|
|
|
|
|
|
class Cases:
    """
    Ordered collection of :class:`UseCase` objects supporting substring filters.
    """

    def __init__(self):
        self._cases = []

    def __iter__(self):
        return iter(self._cases)

    def append(self, use_case):
        """Add `use_case` at the end of the collection."""
        self._cases.append(use_case)

    def _filter(self, field, pattern):
        """Keep the cases whose attribute named `field` contains `pattern`."""
        return [entry for entry in self._cases if pattern in getattr(entry, field)]

    def filter_by_id(self, pattern):
        """
        Select use cases whose HTML `id` attribute contains `pattern`.

        :param pattern: :class:`str` searched for inside `id`

        :return: :class:`list` of matching :class:`UseCase`
        """
        return self._filter("id", pattern)

    def filter_by_value(self, pattern):
        """
        Select use cases whose HTML `value` attribute contains `pattern`.

        :param pattern: :class:`str` searched for inside `value`

        :return: :class:`list` of matching :class:`UseCase`
        """
        return self._filter("value", pattern)
|
|
|
|
|
|
def parse_html_index(filepath):
    """
    Parse HTML index containing all the elements that can be handled by a webdriver.
    Each ``<input type="button">`` element will be turned into a :class:`UseCase`
    which will be appended to a container of :class:`Cases`.

    :param filepath: path to index file as :class:`pathlib.Path`

    :return: :class:`Cases`
    """
    cases = Cases()

    # Decode explicitly instead of relying on the locale's default encoding.
    soup = BeautifulSoup(filepath.read_text(encoding="utf-8"), "html.parser")
    for tag in soup.find_all("input"):
        # `type` is optional on <input> (HTML defaults it to "text") and its
        # value is case-insensitive, so don't index it directly.
        if tag.get("type", "").lower() != "button":
            continue

        # Optional `max` attribute holds the case timeout in seconds.
        case_timeout_seconds = int(tag.get("max", "60"))
        cases.append(UseCase(tag["id"], tag["value"], case_timeout_seconds))

    return cases
|
|
|
|
|
|
def run_case(driver, case):
    """
    Run test or benchmark case using a web driver.
    If case is too long to run, it will raise a :exc:`TimeoutException`.

    Once the case button has been clicked, the page is always refreshed before
    returning or raising, so that the next case starts from a fresh page.

    :param driver: :class:`Driver`
    :param case: :class:`UseCase`

    :return: :class:`dict` of benchmark results if `case` is benchmarks otherwise `None`
    """
    page_waiter = driver.get_waiter(10)
    test_waiter = driver.get_waiter(case.timeout_seconds)

    print("[driver] Wait for page to load")
    driver.wait_for_page_load(page_waiter)

    print(f"[driver] Wait for HTML button to be clickable (id: {case.id})")
    button = driver.wait_for_button(page_waiter, case.id)
    button.click()

    checkbox_id = "testSuccess"
    try:
        checkbox = driver.find_element(checkbox_id)
        try:
            print("[driver] Wait for result checkbox to be checked")
            driver.wait_for_selection(test_waiter, checkbox)
        except TimeoutException as err:
            raise TimeoutException(
                f"timed out after {case.timeout_seconds} seconds waiting for result checkbox to be checked"
            ) from err

        benchmark_results = driver.find_element("benchmarkResults").get_attribute(
            "value"
        )
    finally:
        # Whatever happened after the click, leave a clean page for the next case.
        driver.refresh()

    return json.loads(benchmark_results) if benchmark_results else None
|
|
|
|
|
|
def dump_benchmark_results(
    results, browser_kind, output_path="tfhe/wasm_benchmark_results.json"
):
    """
    Dump as JSON benchmark results into a file.
    If `results` is empty then this function is a no-op (no file is written).

    Every ``mean`` substring in a result key is replaced by ``<browser>_mean``
    (e.g. ``enc_mean`` becomes ``enc_chrome_mean``), so that results coming
    from different browsers do not collide.

    :param results: benchmark results as :class:`dict`
    :param browser_kind: browser as :class:`BrowserKind`
    :param output_path: destination file path (:class:`str` or
        :class:`pathlib.Path`); relative paths are resolved from the
        current working directory
    """
    if not results:
        return

    browser_mean = f"{browser_kind.name}_mean"
    renamed = {key.replace("mean", browser_mean): val for key, val in results.items()}
    pathlib.Path(output_path).write_text(json.dumps(renamed))
|
|
|
|
|
|
def start_web_server(
    command, working_directory, server_address, server_port, startup_timeout_seconds=30
):
    """
    Start web server with custom command as a subprocess.

    The server is started in a new session so that :func:`terminate_web_server`
    can kill its whole process group. If the server does not accept TCP
    connections within `startup_timeout_seconds`, it is terminated.

    :param command: command to start the server as :class:`str` (shell-style
        quoting is honoured, but no shell is involved)
    :param working_directory: path to directory to move before running `command`
    :param server_address: web server address
    :param server_port: web server port as :class:`int`
    :param startup_timeout_seconds: duration in seconds to let server start up

    :raises ConnectionError: something already listens at address/port
    :raises TimeoutError: server did not come up in time

    :return: :class:`subprocess.Popen`
    """
    try:
        sock = socket.create_connection((server_address, server_port), timeout=2)
    except (TimeoutError, ConnectionRefusedError):
        # Nothing is alive at this URL, ignoring exception
        pass
    else:
        sock.close()
        raise ConnectionError(
            f"address and port already in use at ({server_address}, {server_port})"
        )

    # shlex.split() honours quoted arguments, unlike str.split().
    proc = subprocess.Popen(
        shlex.split(command),
        cwd=working_directory,
        stdout=subprocess.DEVNULL,
        start_new_session=True,
    )

    print("Starting web server")

    poll_timeout_seconds = 0.5
    # Monotonic clock: unaffected by wall-clock adjustments while waiting.
    deadline = time.monotonic() + startup_timeout_seconds
    try:
        while time.monotonic() < deadline:
            try:
                sock = socket.create_connection(
                    (server_address, server_port), timeout=poll_timeout_seconds
                )
            except TimeoutError:
                pass
            except ConnectionRefusedError:
                time.sleep(poll_timeout_seconds)
            else:
                sock.close()
                break
        else:
            raise TimeoutError(
                f"timeout after {startup_timeout_seconds} seconds while waiting for web server"
            )
    except BaseException:
        # The server runs in its own session; don't leave it orphaned on any
        # failure or interruption while waiting for it.
        terminate_web_server(proc.pid)
        raise

    return proc
|
|
|
|
|
|
def terminate_web_server(pid):
    """
    Terminate web server process.

    Sends SIGTERM to the whole process group of `pid`. If that process no
    longer exists, there is nothing left to terminate and the call is a no-op.

    :param pid: process ID as :class:`int`
    """
    # Killing process group since the server is a child process of
    # spawned subprocess. Using a simple kill() would leave the server
    # alive even after exiting this program.
    try:
        os.killpg(os.getpgid(pid), signal.SIGTERM)
    except ProcessLookupError:
        # Process already exited (and was reaped): nothing to terminate.
        pass
|
|
|
|
|
|
def main():
    """
    Entry point: run the selected tests/benchmarks in a headless browser.

    Exit code is 0 on success, 1 if any case failed (or the web server could not
    be started) and 2 if the run was interrupted with Ctrl-C.
    """
    args = parser.parse_args()
    browser_kind = BrowserKind[args.browser_kind]

    exit_code = 0

    cases = parse_html_index(pathlib.Path(args.index_path))
    # The ID pattern takes precedence; both filters are never combined.
    if args.id_filter_pattern:
        cases = cases.filter_by_id(args.id_filter_pattern)
    elif args.value_filter_pattern:
        cases = cases.filter_by_value(args.value_filter_pattern)

    server_process = None
    if args.server_cmd:
        try:
            server_process = start_web_server(
                args.server_cmd, args.server_workdir, args.address, args.port
            )
        except Exception as err:
            print(f"Failed to start web server (error: {err})")
            sys.exit(1)

    failures = []
    benchmark_results = {}
    driver = None

    # Browser and server must be shut down even if the driver setup, the page
    # load or a case blows up; the server would otherwise outlive this script.
    try:
        print("Starting web driver")
        driver = Driver(
            args.browser_path, args.driver_path, browser_kind, threaded_logs=True
        )

        driver.get_page(f"http://{args.address}:{args.port}", timeout_seconds=10)

        for case in cases:
            try:
                bench_res = run_case(driver, case)
                print(f"SUCCESS: {case.id}\n")
                if bench_res:
                    benchmark_results.update(bench_res)
            except KeyboardInterrupt:
                exit_code = 2
                break
            except Exception as error:
                print(f"FAIL: {case.id} (reason: {error})\n")
                # Record the failure in every mode so the final summary is complete.
                failures.append(case.id)
                if args.fail_fast:
                    print("Fail fast is enabled, exiting")
                    exit_code = 1
                    break

        dump_benchmark_results(benchmark_results, browser_kind)
    finally:
        # Close the browser
        if driver is not None:
            driver.quit()

        if server_process:
            print("Shutting down web server")
            terminate_web_server(server_process.pid)

    if failures:
        # Keep a more specific non-zero code (2 = interrupted) if one was set.
        exit_code = exit_code or 1
        print("Following tests have failed:")
        for case_name in failures:
            print(f"* {case_name}")

    sys.exit(exit_code)


if __name__ == "__main__":
    main()
|