Mirror of https://github.com/zama-ai/tfhe-rs.git, synced 2026-01-06 21:34:05 -05:00
Regression benchmarks are meant to be run in pull requests. They can be launched in two ways:

* issue comment: using a command such as `/bench --backend cpu`
* label: adding `bench-perfs-cpu` or `bench-perfs-gpu`

Benchmark definitions are written in TOML and located at ci/regression.toml. While not exhaustive, the file can easily be modified by following its embedded documentation.

`/bench` commands are parsed by a Python script located at ci/perf_regression.py. This script produces output files containing cargo commands, along with a shell script that generates custom environment variables. The Python script and its generated files are meant to be used only by the benchmark_perf_regression.yml workflow.
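For illustration, here is a minimal sketch of how a `/bench` issue comment could be tokenized and validated before dispatching the workflow. The `parse_bench_comment` helper and `bench_parser` below are hypothetical: the real parsing lives in ci/perf_regression.py and may accept more options than the `--backend` flag shown here.

```python
import argparse
import shlex

# Hypothetical sketch only: ci/perf_regression.py defines the real options.
bench_parser = argparse.ArgumentParser(prog="/bench")
bench_parser.add_argument("--backend", choices=["cpu", "gpu"], required=True)


def parse_bench_comment(comment: str) -> argparse.Namespace:
    """Parse an issue comment such as '/bench --backend cpu'."""
    tokens = shlex.split(comment.strip())
    if not tokens or tokens[0] != "/bench":
        raise ValueError("not a /bench command")
    return bench_parser.parse_args(tokens[1:])


print(parse_bench_comment("/bench --backend cpu").backend)  # -> cpu
```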
97 lines · 2.4 KiB · Python
"""
|
|
hardware_finder
|
|
---------------
|
|
|
|
This script parses ci/slab.toml file to find the hardware name associated with a given pair of backend and a profile name.
|
|
"""
|
|
|
|
import argparse
|
|
import enum
|
|
import pathlib
|
|
import sys
|
|
import tomllib
|
|
from typing import Any
|
|
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument(
|
|
"backend",
|
|
choices=["aws", "hyperstack"],
|
|
help="Backend instance provider",
|
|
)
|
|
parser.add_argument(
|
|
"profile",
|
|
help="Instance profile name",
|
|
)
|
|
|
|
SLAB_FILE = pathlib.Path("ci/slab.toml")
|
|
|
|
|
|
class Backend(enum.StrEnum):
|
|
Aws = "aws"
|
|
Hyperstack = "hyperstack"
|
|
Hpu = "hpu" # Only v80 is supported for now
|
|
|
|
@staticmethod
|
|
def from_str(label):
|
|
match label.lower():
|
|
case "aws":
|
|
return Backend.Aws
|
|
case "hyperstack":
|
|
return Backend.Hyperstack
|
|
case _:
|
|
raise NotImplementedError
|
|
|
|
|
|
def parse_toml_file(path):
|
|
"""
|
|
Parse TOML file.
|
|
|
|
:param path: path to TOML file
|
|
:return: file content as :class:`dict`
|
|
"""
|
|
try:
|
|
return tomllib.loads(pathlib.Path(path).read_text())
|
|
except tomllib.TOMLDecodeError as err:
|
|
raise RuntimeError(f"failed to parse definition file (error: {err})")
|
|
|
|
|
|
def find_hardware_name(config_file: dict[str, Any], backend: Backend, profile: str):
|
|
"""
|
|
Find hardware name associated with :class:`Backend` and :class:`str` profile name.
|
|
|
|
:param config_file: parsed slab.toml file
|
|
:param backend: backend name
|
|
:param profile: profile name
|
|
|
|
:return: hardware name as :class:`str`
|
|
"""
|
|
try:
|
|
definition = config_file["backend"][backend.value][profile]
|
|
except KeyError:
|
|
section_name = f"backend.{backend.value}.{profile}"
|
|
raise KeyError(f"no definition found for `[{section_name}]` in {SLAB_FILE}")
|
|
|
|
match backend:
|
|
case Backend.Aws:
|
|
return definition["instance_type"]
|
|
case Backend.Hyperstack:
|
|
return definition["flavor_name"]
|
|
case _:
|
|
raise NotImplementedError
|
|
|
|
|
|
if __name__ == "__main__":
|
|
args = parser.parse_args()
|
|
|
|
parsed_toml = parse_toml_file(SLAB_FILE)
|
|
backend = Backend.from_str(args.backend)
|
|
try:
|
|
hardware_name = find_hardware_name(parsed_toml, backend, args.profile)
|
|
except Exception as err:
|
|
print(
|
|
f"failed to find hardware name for ({args.backend}, {args.profile}): {err}"
|
|
)
|
|
sys.exit(1)
|
|
else:
|
|
print(hardware_name)
|
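As a quick usage illustration, the sketch below reproduces the lookup that find_hardware_name performs, on a hand-written slab.toml fragment. The section layout and keys (`backend.<provider>.<profile>`, `instance_type` for AWS, `flavor_name` for Hyperstack) follow the accessors in the script above; the profile names and hardware values are made up, not taken from the real ci/slab.toml.

```python
import tomllib

# Hypothetical slab.toml fragment; the real profiles live in ci/slab.toml.
SAMPLE = """
[backend.aws.bench-profile]
instance_type = "m6i.metal"

[backend.hyperstack.gpu-profile]
flavor_name = "n3-H100x1"
"""

config = tomllib.loads(SAMPLE)

# Same lookup shape as find_hardware_name: backend.<provider>.<profile>,
# then the provider-specific key.
print(config["backend"]["aws"]["bench-profile"]["instance_type"])     # m6i.metal
print(config["backend"]["hyperstack"]["gpu-profile"]["flavor_name"])  # n3-H100x1
```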