mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-14 09:08:06 -05:00
Compare commits
1 Commits
main
...
dt/ci/erc2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9d5a36c4c5 |
@@ -279,9 +279,10 @@ class ErrorFailureProbability(enum.IntEnum):
|
||||
return self.to_str()
|
||||
|
||||
|
||||
class BenchType(enum.Enum):
|
||||
Latency = 0
|
||||
Throughput = 1
|
||||
class BenchType(enum.StrEnum):
|
||||
Latency = "Latency"
|
||||
Throughput = "Throughput"
|
||||
Both = "Both"
|
||||
|
||||
@staticmethod
|
||||
def from_str(bench_type):
|
||||
@@ -290,10 +291,24 @@ class BenchType(enum.Enum):
|
||||
return BenchType.Latency
|
||||
case "throughput":
|
||||
return BenchType.Throughput
|
||||
case "both":
|
||||
return BenchType.Both
|
||||
case _:
|
||||
raise ValueError(f"BenchType '{bench_type}' not supported")
|
||||
|
||||
|
||||
class BenchSubset(enum.StrEnum):
|
||||
Erc20 = "erc20"
|
||||
|
||||
@staticmethod
|
||||
def from_str(bench_subset):
|
||||
match bench_subset.lower():
|
||||
case "erc20":
|
||||
return BenchSubset.Erc20
|
||||
case _:
|
||||
raise ValueError(f"BenchSubset '{bench_subset}' not supported")
|
||||
|
||||
|
||||
class ParamsDefinition:
|
||||
"""
|
||||
Represents a parameter definition for specific cryptographic settings.
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import argparse
|
||||
import pathlib
|
||||
|
||||
from benchmark_specs import Backend, BenchType, Layer, PBSKind
|
||||
from benchmark_specs import Backend, BenchSubset, BenchType, Layer, PBSKind
|
||||
|
||||
|
||||
class UserConfig:
|
||||
@@ -33,6 +33,12 @@ class UserConfig:
|
||||
|
||||
self.bench_type = BenchType.from_str(input_args.bench_type.lower())
|
||||
|
||||
self.bench_subset = (
|
||||
BenchSubset.from_str(input_args.bench_subset)
|
||||
if input_args.bench_subset
|
||||
else None
|
||||
)
|
||||
|
||||
self.layer = Layer.from_str(input_args.layer.lower())
|
||||
self.pbs_kind = PBSKind.from_str(input_args.pbs_kind)
|
||||
self.grouping_factor = input_args.grouping_factor
|
||||
|
||||
@@ -228,6 +228,9 @@ class PostgreConnector:
|
||||
filters.append("test.name NOT SIMILAR TO '%::throughput::%'")
|
||||
case BenchType.Throughput:
|
||||
filters.append("test.name LIKE '%::throughput::%'")
|
||||
case BenchType.Both:
|
||||
# No need to add a filter.
|
||||
pass
|
||||
|
||||
select_parts = (
|
||||
"SELECT",
|
||||
|
||||
@@ -25,7 +25,7 @@ import formatters.core
|
||||
import formatters.hlapi
|
||||
import formatters.integer
|
||||
import regression
|
||||
from benchmark_specs import BenchType, Layer, OperandType, RustType
|
||||
from benchmark_specs import BenchSubset, BenchType, Layer, OperandType, RustType
|
||||
from formatters.common import BenchArray, CSVFormatter, MarkdownFormatter, SVGFormatter
|
||||
|
||||
import utils
|
||||
@@ -129,10 +129,18 @@ parser.add_argument(
|
||||
parser.add_argument(
|
||||
"--bench-type",
|
||||
dest="bench_type",
|
||||
choices=["latency", "throughput"],
|
||||
choices=["latency", "throughput", "both"],
|
||||
default="latency",
|
||||
help="Type of benchmark to filter against",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bench-subset",
|
||||
dest="bench_subset",
|
||||
choices=[
|
||||
"erc20",
|
||||
],
|
||||
help="Subset of benchmarks to filter against, dedicated formatting will be applied",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--regression-profiles",
|
||||
dest="regression_profiles",
|
||||
@@ -269,14 +277,20 @@ def perform_hardware_comparison(
|
||||
)
|
||||
|
||||
|
||||
def get_formatter(layer: Layer):
|
||||
def get_formatter(layer: Layer, bench_subset: BenchSubset = None):
|
||||
match layer:
|
||||
case Layer.Integer:
|
||||
return formatters.integer.IntegerFormatter
|
||||
case Layer.CoreCrypto:
|
||||
return formatters.core.CoreFormatter
|
||||
case Layer.HLApi:
|
||||
return formatters.hlapi.HlApiFormatter
|
||||
formatter = formatters.hlapi.HlApiFormatter
|
||||
if bench_subset:
|
||||
match bench_subset:
|
||||
case BenchSubset.Erc20:
|
||||
formatter = formatters.hlapi.Erc20Formatter
|
||||
|
||||
return formatter
|
||||
case _:
|
||||
raise NotImplementedError(f"layer '{layer}' not supported yet")
|
||||
|
||||
@@ -286,6 +300,7 @@ def perform_data_extraction(
|
||||
layer: Layer,
|
||||
operand_type: OperandType,
|
||||
output_filename: str,
|
||||
bench_subset: BenchSubset = None,
|
||||
):
|
||||
"""
|
||||
Extracts, formats, and processes benchmark data for a specified operand type and
|
||||
@@ -308,8 +323,12 @@ def perform_data_extraction(
|
||||
:return: Generic formatted arrays
|
||||
:rtype: list[BenchArray]
|
||||
"""
|
||||
operation_filter = [bench_subset.value] if bench_subset else None
|
||||
|
||||
try:
|
||||
res = conn.fetch_benchmark_data(user_config, operand_type)
|
||||
res = conn.fetch_benchmark_data(
|
||||
user_config, operand_type, operation_filter=operation_filter
|
||||
)
|
||||
except RuntimeError as err:
|
||||
print(f"Failed to fetch benchmark data: {err}")
|
||||
sys.exit(2)
|
||||
@@ -319,8 +338,10 @@ def perform_data_extraction(
|
||||
conversion_func = utils.convert_latency_value_to_readable_text
|
||||
case BenchType.Throughput:
|
||||
conversion_func = utils.convert_throughput_value_to_readable_text
|
||||
case BenchType.Both:
|
||||
conversion_func = None
|
||||
|
||||
generic_formatter_class = get_formatter(layer)
|
||||
generic_formatter_class = get_formatter(layer, bench_subset=bench_subset)
|
||||
generic_formatter = generic_formatter_class(
|
||||
layer, user_config.backend, user_config.pbs_kind, user_config.grouping_factor
|
||||
)
|
||||
@@ -335,12 +356,16 @@ def perform_data_extraction(
|
||||
file_suffix = ""
|
||||
filename = utils.append_suffix_to_filename(output_filename, file_suffix, ".csv")
|
||||
|
||||
utils.write_to_csv(
|
||||
CSVFormatter(layer, user_config.backend, user_config.pbs_kind).generate_csv(
|
||||
formatted_results
|
||||
),
|
||||
filename,
|
||||
)
|
||||
try:
|
||||
utils.write_to_csv(
|
||||
CSVFormatter(layer, user_config.backend, user_config.pbs_kind).generate_csv(
|
||||
formatted_results
|
||||
),
|
||||
filename,
|
||||
)
|
||||
except NotImplementedError as err:
|
||||
# Ignore this error if a formatter does not support CSV generation.
|
||||
print(f"CSV generation not supported (error: {err})")
|
||||
|
||||
generic_arrays = generic_formatter.generate_array(
|
||||
formatted_results,
|
||||
@@ -404,6 +429,7 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
user_config = config.UserConfig(args)
|
||||
layer = user_config.layer
|
||||
bench_subset = user_config.bench_subset
|
||||
|
||||
if args.generate_svg_from_file:
|
||||
generate_svg_from_file(user_config, layer, args.generate_svg_from_file)
|
||||
@@ -454,12 +480,18 @@ if __name__ == "__main__":
|
||||
print("Markdown generation is not supported with comparisons")
|
||||
continue
|
||||
|
||||
if layer == Layer.CoreCrypto and operand_type == OperandType.PlainText:
|
||||
if (
|
||||
layer == Layer.CoreCrypto or (layer == Layer.HLApi and bench_subset)
|
||||
) and operand_type == OperandType.PlainText:
|
||||
continue
|
||||
|
||||
file_suffix = f"_{operand_type.lower()}"
|
||||
arrays = perform_data_extraction(
|
||||
user_config, layer, operand_type, user_config.output_file
|
||||
user_config,
|
||||
layer,
|
||||
operand_type,
|
||||
user_config.output_file,
|
||||
bench_subset=bench_subset,
|
||||
)
|
||||
generate_files_from_arrays(
|
||||
arrays,
|
||||
|
||||
@@ -159,7 +159,7 @@ class GenericFormatter:
|
||||
self.requested_grouping_factor = grouping_factor
|
||||
|
||||
def format_data(
|
||||
self, data: dict[BenchDetails : list[int]], conversion_func: Callable
|
||||
self, data: dict[BenchDetails : list[int]], conversion_func: Callable = None
|
||||
):
|
||||
"""
|
||||
Formats data based on the specified layer and applies a conversion function to
|
||||
@@ -175,7 +175,7 @@ class GenericFormatter:
|
||||
:type data: dict[BenchDetails : list[int]]
|
||||
:param conversion_func: A callable function that will be applied to transform
|
||||
the data values based on the specific layer requirements.
|
||||
:type conversion_func: Callable
|
||||
:type conversion_func: Callable, optional
|
||||
|
||||
:return: The formatted data results after applying layer and conversion logic.
|
||||
:rtype: Any
|
||||
@@ -269,10 +269,9 @@ class CSVFormatter(GenericFormatter):
|
||||
case Layer.CoreCrypto:
|
||||
headers = ["Operation \\ Parameters set", *headers_values]
|
||||
case _:
|
||||
print(
|
||||
raise NotImplementedError(
|
||||
f"tfhe-rs layer '{self.layer}' currently not supported for CSV writing"
|
||||
)
|
||||
raise NotImplementedError
|
||||
|
||||
csv_data = [headers]
|
||||
csv_data.extend(
|
||||
@@ -374,6 +373,14 @@ class SVGFormatter(GenericFormatter):
|
||||
for row_idx, type_ident in enumerate(headers):
|
||||
curr_x = op_name_col_width + row_idx * per_timing_col_width
|
||||
|
||||
header_one_row_span = self._build_svg_text(
|
||||
curr_x + per_timing_col_width / 2,
|
||||
row_height / 2,
|
||||
type_ident,
|
||||
fill=WHITE_COLOR,
|
||||
font_weight="bold",
|
||||
)
|
||||
|
||||
match layer:
|
||||
case Layer.Integer:
|
||||
if type_ident.startswith("FheUint"):
|
||||
@@ -399,28 +406,16 @@ class SVGFormatter(GenericFormatter):
|
||||
]
|
||||
)
|
||||
else: # Backends comparison (CPU, GPU, HPU)
|
||||
header_elements.append(
|
||||
self._build_svg_text(
|
||||
curr_x + per_timing_col_width / 2,
|
||||
row_height / 2,
|
||||
type_ident,
|
||||
fill=WHITE_COLOR,
|
||||
font_weight="bold",
|
||||
)
|
||||
)
|
||||
header_elements.append(header_one_row_span)
|
||||
case Layer.CoreCrypto:
|
||||
header_elements.append(
|
||||
# Core_crypto arrays contains only ciphertext modulus size as headers
|
||||
self._build_svg_text(
|
||||
curr_x + per_timing_col_width / 2,
|
||||
row_height / 2,
|
||||
type_ident,
|
||||
fill=WHITE_COLOR,
|
||||
font_weight="bold",
|
||||
)
|
||||
)
|
||||
# Core_crypto arrays contains only ciphertext modulus size as headers
|
||||
header_elements.append(header_one_row_span)
|
||||
case Layer.HLApi:
|
||||
header_elements.append(header_one_row_span)
|
||||
case _:
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError(
|
||||
f"svg header row generation not supported for '{layer}' layer"
|
||||
)
|
||||
|
||||
return header_elements
|
||||
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
import collections
|
||||
|
||||
from benchmark_specs import BenchDetails
|
||||
from formatters.common import GenericFormatter
|
||||
from benchmark_specs import Backend, BenchDetails, BenchType
|
||||
from formatters.common import BenchArray, GenericFormatter
|
||||
|
||||
import utils
|
||||
|
||||
|
||||
class HlApiFormatter(GenericFormatter):
|
||||
"""
|
||||
Formatter for arithmetic operations benchmarks.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _format_data(data: dict[BenchDetails : list[int]], conversion_func):
|
||||
formatted = collections.defaultdict(
|
||||
@@ -28,3 +34,56 @@ class HlApiFormatter(GenericFormatter):
|
||||
formatted[test_name][bit_width] = value
|
||||
|
||||
return formatted
|
||||
|
||||
|
||||
TRANSFER_IMPLEM_COLUMN_HEADER = "Transfer implementation"
|
||||
|
||||
|
||||
class Erc20Formatter(HlApiFormatter):
|
||||
"""
|
||||
Formatter for ERC20 benchmarks.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _format_data(data: dict[BenchDetails : list[int]], *args):
|
||||
formatted = collections.defaultdict(
|
||||
lambda: {
|
||||
BenchType.Latency: "N/A",
|
||||
BenchType.Throughput: "N/A",
|
||||
}
|
||||
)
|
||||
|
||||
for details, timings in data.items():
|
||||
name_parts = details.operation_name.split("::")
|
||||
test_name = name_parts[name_parts.index("transfer") + 1]
|
||||
if "throughput" in name_parts:
|
||||
bench_type = BenchType.Throughput
|
||||
conversion_func = utils.convert_throughput_value_to_readable_text
|
||||
else:
|
||||
bench_type = BenchType.Latency
|
||||
conversion_func = utils.convert_latency_value_to_readable_text
|
||||
|
||||
# For now ERC20 benchmarks are only made on 64-bit ciphertexts.
|
||||
value = conversion_func(timings[-1])
|
||||
formatted[test_name][bench_type] = value
|
||||
|
||||
return formatted
|
||||
|
||||
def _generate_arrays(self, data, *args, **kwargs):
|
||||
first_column_header = TRANSFER_IMPLEM_COLUMN_HEADER
|
||||
|
||||
match self.backend:
|
||||
case Backend.HPU:
|
||||
op_names = ["whitepaper", "hpu_optim", "hpu_simd"]
|
||||
case _:
|
||||
op_names = ["whitepaper", "no_cmux", "overflow"]
|
||||
|
||||
result_lines = []
|
||||
for op_name in op_names:
|
||||
line = {first_column_header: op_name}
|
||||
line.update({str(bench_type): v for bench_type, v in data[op_name].items()})
|
||||
result_lines.append(line)
|
||||
|
||||
return [
|
||||
BenchArray(result_lines, self.layer),
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user