Files
tfhe-rs/ci/webdriver.py

523 lines
15 KiB
Python

"""
webdriver
---------
Script to handle tests and benchmarks for client-side tfhe-rs WASM code.
"""
import argparse
import dataclasses
import datetime
import enum
import json
import os
import pathlib
import shlex
import signal
import socket
import subprocess
import sys
import threading
import time

from bs4 import BeautifulSoup
from selenium import webdriver
from selenium.common.exceptions import TimeoutException
from selenium.webdriver import Keys
from selenium.webdriver.chrome.options import Options as ChromeOptions
from selenium.webdriver.chrome.service import Service as ChromeService
from selenium.webdriver.common.by import By
from selenium.webdriver.firefox.options import Options as FirefoxOptions
from selenium.webdriver.firefox.service import Service as FirefoxService
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.ui import WebDriverWait
# Command-line interface of the script; parsed in main().
parser = argparse.ArgumentParser(
    description="Run client-side tfhe-rs WASM tests/benchmarks in a web browser."
)
parser.add_argument(
    "-a",
    "--address",
    dest="address",
    default="localhost",
    help="Address to testing Node server",
)
parser.add_argument(
    "-p",
    "--port",
    dest="port",
    default=3000,
    type=int,
    help="Port to testing Node server",
)
parser.add_argument(
    "-k",
    "--browser-kind",
    dest="browser_kind",
    choices=["chrome", "firefox"],
    required=True,
    help="Kind of web browser to drive",
)
parser.add_argument(
    "-b",
    "--browser-path",
    dest="browser_path",
    required=True,
    help="Path to browser file",
)
parser.add_argument(
    "-d",
    "--driver-path",
    dest="driver_path",
    required=True,
    help="Path to web driver file",
)
parser.add_argument(
    "--index-path",
    dest="index_path",
    default="tfhe/web_wasm_parallel_tests/index.html",
    help="Path to HTML index file containing all the tests/benchmarks",
)
parser.add_argument(
    "--id-pattern",
    dest="id_filter_pattern",
    help=(
        "Pattern to use to filter HTML button ID displayed on web page "
        "(takes precedence over --value-pattern)"
    ),
)
parser.add_argument(
    "--value-pattern",
    dest="value_filter_pattern",
    help="Pattern to use to filter HTML button value displayed on web page",
)
parser.add_argument(
    "-f",
    "--fail-fast",
    dest="fail_fast",
    action="store_true",
    help="Exit on first failed test",
)
parser.add_argument(
    "--server-cmd",
    dest="server_cmd",
    help="Command to execute to launch web server in the background",
)
parser.add_argument(
    "--server-workdir",
    dest="server_workdir",
    help="Path to working directory to launch web server",
)
class BrowserKind(enum.Enum):
    """
    Web browsers this script knows how to drive.

    Member names are the same strings as the ``--browser-kind`` choices, which
    lets ``BrowserKind[name]`` turn a parsed argument into a member.
    """

    chrome = 1
    firefox = 2
class Driver:
"""
Representation of a web driver relying on Selenium.
"""
def __init__(self, browser_path, driver_path, browser_kind, threaded_logs=False):
"""
:param browser_path: path to binary web browser as :class:`str`
:param driver_path: path to binary web driver as :class:`str`
:param browser_kind: :class:`BrowserKind`
:param threaded_logs: launch a thread to display log in parallel
"""
self.browser_path = browser_path
self.driver_path = driver_path
self._is_threaded_logs = threaded_logs
self._log_thread = None
self.browser_kind = browser_kind
match self.browser_kind:
case BrowserKind.chrome:
self.options = ChromeOptions()
if os.getuid() == 0:
# If user ID is root then driver needs to run in no-sandbox mode.
print(
"Script is running as root, running browser with --no-sandbox for compatibility"
)
self.options.add_argument("--no-sandbox")
case BrowserKind.firefox:
self.options = FirefoxOptions()
self.options.binary_location = self.browser_path
self.options.add_argument("--headless")
self._driver = None
self.shutting_down = False
def get_driver(self):
if self._driver is None:
match self.browser_kind:
case BrowserKind.chrome:
driver_service = ChromeService(self.driver_path)
self.options.set_capability("goog:loggingPrefs", {"browser": "ALL"})
self._driver = webdriver.Chrome(
service=driver_service, options=self.options
)
if self._is_threaded_logs:
self._log_thread = threading.Thread(target=self._threaded_logs)
case BrowserKind.firefox:
driver_service = FirefoxService(self.driver_path)
self.options.log.level = "trace"
self.options.enable_bidi = True
self._driver = webdriver.Firefox(
service=driver_service, options=self.options
)
self._driver.script.add_console_message_handler(
self._on_console_logs
)
case _:
print(
f"{self.browser_kind.name.capitalize()} browser driver is not supported"
)
sys.exit(1)
if self._log_thread:
self._log_thread.start()
return self._driver
def get_page(self, server_url, timeout_seconds=10):
dr = self.get_driver()
dr.get(server_url)
self.wait_for_page_load(self.get_waiter(timeout_seconds))
def get_waiter(self, timeout):
return WebDriverWait(self.get_driver(), timeout)
def wait_for_page_load(self, waiter):
waiter.until(
lambda d: d.execute_script("return document.readyState") == "complete"
)
def wait_for_button(self, waiter, element_id):
return waiter.until(EC.element_to_be_clickable((By.ID, element_id)))
def wait_for_selection(self, waiter, element):
return waiter.until(EC.element_to_be_selected(element))
def find_element(self, element_id):
return self.get_driver().find_element(By.ID, element_id)
def _on_console_logs(self, log):
"""
Callback used for retrieving console log using BiDi protocol reling on websocket
"""
# Filter out useless message
if "using deprecated parameters" in log.text:
return
print(f"{log.level.upper()}: {log.text}")
def print_log(self, log_type):
logs = self.get_driver().get_log(log_type)
for log in logs:
# Filter out useless message
if "using deprecated parameters" in log["message"]:
continue
# String pattern is `<server url> <line:col> "<log message>"`
# We only care for <log message> part.
content = log["message"].split(maxsplit=2)[-1].strip('"')
print(f"{log['level']}: {content}")
def _threaded_logs(self):
while not self.shutting_down:
self.print_log("browser")
time.sleep(0.2)
def refresh(self):
match self.browser_kind:
case BrowserKind.chrome:
self.get_driver().refresh()
case BrowserKind.firefox:
# Need to force refresh in Firefox to avoid script caching by web workers
self.get_driver().find_element(By.TAG_NAME, "body").send_keys(
Keys.CONTROL + Keys.SHIFT + "R"
)
def quit(self):
self.shutting_down = True
if self._log_thread:
self._log_thread.join()
if self._driver:
self.get_driver().quit()
@dataclasses.dataclass()
class UseCase:
    """
    One runnable test or benchmark, built from an HTML ``<input>`` button.
    """

    # HTML `id` attribute of the button that starts the case.
    id: str
    # HTML `value` attribute, i.e. the label shown on the button.
    value: str
    # Maximum time, in seconds, to wait for the case to report completion.
    timeout_seconds: int
class Cases:
    """
    Ordered collection of :class:`UseCase` objects with substring filters.
    """

    def __init__(self):
        self._cases = []

    def __iter__(self):
        return iter(self._cases)

    def append(self, use_case):
        """Add `use_case` at the end of the collection."""
        self._cases.append(use_case)

    def _filter(self, field, pattern):
        """Return, in order, the cases whose attribute `field` contains `pattern`."""
        return list(filter(lambda item: pattern in getattr(item, field), self._cases))

    def filter_by_id(self, pattern):
        """
        Keep the use cases whose HTML `id` attribute contains `pattern`.

        :param pattern: :class:`str` to look for inside `id`
        :return: plain :class:`list` of matching :class:`UseCase`
        """
        return self._filter("id", pattern)

    def filter_by_value(self, pattern):
        """
        Keep the use cases whose HTML `value` attribute contains `pattern`.

        :param pattern: :class:`str` to look for inside `value`
        :return: plain :class:`list` of matching :class:`UseCase`
        """
        return self._filter("value", pattern)
def parse_html_index(filepath):
    """
    Parse HTML index containing all the elements that can be handled by a webdriver.

    Each ``<input type="button">`` element becomes a :class:`UseCase` built from
    its ``id``, its ``value`` (empty when absent) and its optional ``max``
    attribute, used as the timeout in seconds (60 by default). Every case is
    appended to a :class:`Cases` container. Other ``<input>`` elements are
    ignored.

    :param filepath: path to index file as :class:`pathlib.Path`
    :return: :class:`Cases`
    :raises KeyError: if a button has no ``id`` attribute (cases are driven by id)
    :raises ValueError: if a ``max`` attribute is not an integer
    """
    cases = Cases()
    soup = BeautifulSoup(filepath.read_text(encoding="utf-8"), "html.parser")
    for tag in soup.find_all("input"):
        # `type` is optional on <input> (HTML defaults it to "text") and its
        # value is case-insensitive, so it must not be indexed directly.
        if tag.get("type", "").lower() != "button":
            continue
        case_timeout_seconds = int(tag.get("max", "60"))
        cases.append(UseCase(tag["id"], tag.get("value", ""), case_timeout_seconds))
    return cases
def run_case(driver, case):
    """
    Run test or benchmark case using a web driver.

    Waits for the page, clicks the case's button, then waits for the
    ``testSuccess`` checkbox to be checked. It then reads the value of the
    ``benchmarkResults`` element and refreshes the page so the next case starts
    from a fresh page.

    :param driver: :class:`Driver`
    :param case: :class:`UseCase`
    :return: :class:`dict` of benchmark results if `case` is a benchmark otherwise `None`
    :raises TimeoutException: if the result checkbox is not checked within
        ``case.timeout_seconds`` (the page is refreshed before raising)
    """
    page_waiter = driver.get_waiter(10)
    test_waiter = driver.get_waiter(case.timeout_seconds)

    print("[driver] Wait for page to load")
    driver.wait_for_page_load(page_waiter)

    print(f"[driver] Wait for HTML button to be clickable (id: {case.id})")
    button = driver.wait_for_button(page_waiter, case.id)
    button.click()

    checkbox_id = "testSuccess"
    checkbox = driver.find_element(checkbox_id)

    try:
        print("[driver] Wait for result checkbox to be checked")
        driver.wait_for_selection(test_waiter, checkbox)
    except TimeoutException as err:
        # Reload so a hung case does not leave the page busy for the next one.
        driver.refresh()
        raise TimeoutException(
            f"timed out after {case.timeout_seconds} seconds waiting for result checkbox to be checked"
        ) from err

    benchmark_results = driver.find_element("benchmarkResults").get_attribute("value")
    driver.refresh()

    # An empty value means the case was a plain test, not a benchmark.
    return json.loads(benchmark_results) if benchmark_results else None
def dump_benchmark_results(
    results,
    browser_kind,
    output_path="tfhe-benchmark/wasm_benchmark_results.json",
):
    """
    Dump as JSON benchmark results into a file.

    Every ``mean`` inside a result key is replaced by ``<browser>_mean`` (e.g.
    ``foo_mean`` becomes ``foo_chrome_mean`` for Chrome), so that results from
    different browsers can be told apart.

    If `results` is empty then this function is a no-op and no file is written.

    :param results: benchmark results as :class:`dict`
    :param browser_kind: browser as :class:`BrowserKind`
    :param output_path: destination file as :class:`str` or :class:`pathlib.Path`
        (relative to the current directory); missing parent directories are created
    """
    if not results:
        return

    browser_results = {
        key.replace("mean", f"{browser_kind.name}_mean"): val
        for key, val in results.items()
    }
    path = pathlib.Path(output_path)
    # The default output directory may not exist yet, e.g. in a fresh checkout.
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(browser_results))
def start_web_server(
    command, working_directory, server_address, server_port, startup_timeout_seconds=30
):
    """
    Start web server with custom command as a subprocess.

    `command` is split with shell-like rules (quotes are honoured) but is not
    run through a shell. The process is started in a new session so the whole
    process group can be signalled later. Returns once the server accepts TCP
    connections. On any failure after spawning, the process group is sent
    SIGTERM before the error is re-raised, so no server is left behind.

    :param command: command to start the server as :class:`str`
    :param working_directory: path to directory to move before running `command`
    :param server_address: web server address
    :param server_port: web server port as :class:`int`
    :param startup_timeout_seconds: duration in seconds to let server start up
    :return: :class:`subprocess.Popen`
    :raises ConnectionError: if something already accepts connections at the address
    :raises RuntimeError: if the server process exits before accepting connections
    :raises TimeoutError: if the server does not accept connections in time
    """
    try:
        sock = socket.create_connection((server_address, server_port), timeout=2)
    except (TimeoutError, ConnectionRefusedError):
        # Nothing is alive at this URL, ignoring exception
        pass
    else:
        sock.close()
        raise ConnectionError(
            f"address and port already in use at ({server_address}, {server_port})"
        )

    proc = subprocess.Popen(
        shlex.split(command),
        cwd=working_directory,
        stdout=subprocess.DEVNULL,
        start_new_session=True,
    )
    print("Starting web server")

    connect_timeout_seconds = 0.5
    retry_interval_seconds = 0.1
    # Monotonic clock: the wait is not affected by wall-clock adjustments.
    deadline = time.monotonic() + startup_timeout_seconds
    try:
        while True:
            return_code = proc.poll()
            if return_code is not None:
                # No point waiting for the full timeout: the server is gone.
                raise RuntimeError(
                    f"web server exited with code {return_code} before accepting connections"
                )
            if time.monotonic() >= deadline:
                raise TimeoutError(
                    f"timeout after {startup_timeout_seconds} seconds while waiting for web server"
                )
            try:
                sock = socket.create_connection(
                    (server_address, server_port), timeout=connect_timeout_seconds
                )
            except (TimeoutError, ConnectionError):
                # Server not ready yet (refused, reset or slow): retry shortly.
                time.sleep(retry_interval_seconds)
            else:
                sock.close()
                return proc
    except BaseException:
        # Never leave the server running when startup failed or was interrupted.
        # start_new_session=True makes the child a process-group leader, so its
        # PID is also its group ID; signal the group to reach its children too.
        try:
            os.killpg(proc.pid, signal.SIGTERM)
        except ProcessLookupError:
            # Group already gone (e.g. the command exited without children).
            pass
        raise
def terminate_web_server(pid):
    """
    Terminate web server process.

    SIGTERM is sent to the whole process group of `pid`. If the process no
    longer exists, this is a no-op instead of an error, so shutdown still
    completes when the server has already died.

    :param pid: process ID as :class:`int`
    """
    # Killing process group since the server is a child process of
    # spawned subprocess. Using a simple kill() would let the server
    # alive even after exiting this program.
    try:
        os.killpg(os.getpgid(pid), signal.SIGTERM)
    except ProcessLookupError:
        # Server already exited (and was reaped): nothing left to terminate.
        pass
def main():
    """
    Command-line entry point: run the selected HTML cases in a browser and exit.

    Exit status is 0 when every case passed and 1 when at least one case failed
    or the web server could not be started. It is 2 when the run was
    interrupted with Ctrl+C. The browser and the optional web server are shut
    down even when the driver fails to start or to load the page.
    """
    args = parser.parse_args()
    browser_kind = BrowserKind[args.browser_kind]
    exit_code = 0

    cases = parse_html_index(pathlib.Path(args.index_path))
    # Filters are exclusive: --id-pattern takes precedence over --value-pattern.
    if args.id_filter_pattern:
        cases = cases.filter_by_id(args.id_filter_pattern)
    elif args.value_filter_pattern:
        cases = cases.filter_by_value(args.value_filter_pattern)

    server_process = None
    if args.server_cmd:
        try:
            server_process = start_web_server(
                args.server_cmd, args.server_workdir, args.address, args.port
            )
        except Exception as err:
            print(f"Failed to start web server (error: {err})")
            sys.exit(1)

    failures = []
    driver = None
    try:
        print("Starting web driver")
        driver = Driver(
            args.browser_path, args.driver_path, browser_kind, threaded_logs=True
        )
        driver.get_page(f"http://{args.address}:{args.port}", timeout_seconds=10)

        benchmark_results = {}
        for case in cases:
            try:
                bench_res = run_case(driver, case)
            except KeyboardInterrupt:
                exit_code = 2
                break
            except Exception as error:
                print(f"FAIL: {case.id} (reason: {error})\n")
                # Record the failure even in fail-fast mode so it gets listed.
                failures.append(case.id)
                if args.fail_fast:
                    print("Fail fast is enabled, exiting")
                    break
            else:
                print(f"SUCCESS: {case.id}\n")
                if bench_res:
                    benchmark_results.update(bench_res)

        dump_benchmark_results(benchmark_results, browser_kind)
    finally:
        # Always release the browser and the server, even if the driver could
        # not start or load the page; the server runs in its own session and
        # would otherwise outlive this script.
        try:
            if driver is not None:
                # Close the browser
                driver.quit()
        finally:
            if server_process:
                print("Shutting down web server")
                terminate_web_server(server_process.pid)

    if failures:
        # Keep the "interrupted" status (2) if the run was stopped with Ctrl+C.
        exit_code = exit_code or 1
        print("Following tests have failed:")
        for case_name in failures:
            print(f"* {case_name}")

    sys.exit(exit_code)


if __name__ == "__main__":
    main()