mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
chore(ci): add backends comparison table generation for docs
This adds backends comparison in data extractor. It performs comparison on a fixed list (CPU, GPU, HPU) for 64 bits precision ciphertext as displayed in tfhe-rs public documentation. SVG table generation is automated via the documentation benchmark workflow.
This commit is contained in:
@@ -185,10 +185,12 @@ jobs:
|
||||
path: svg_tables
|
||||
merge-multiple: 'true'
|
||||
|
||||
# Perform best effort to copy SVG tables. If the copy fails or files don't exist, the PR will still be created.
|
||||
- name: Copy SVG tables to documentation location
|
||||
run: |
|
||||
cp -f svg_tables/*integer-benchmark*.svg "${PATH_TO_DOC_ASSETS}"
|
||||
cp -f svg_tables/*pbs-benchmark-tuniform*.svg "${PATH_TO_DOC_ASSETS}"
|
||||
cp -f svg_tables/*integer-benchmark*.svg "${PATH_TO_DOC_ASSETS}" 2>/dev/null
|
||||
cp -f svg_tables/*pbs-benchmark-tuniform*.svg "${PATH_TO_DOC_ASSETS}" 2>/dev/null
|
||||
cp -f svg_tables/cpu-gpu-hpu-integer-benchmark-fheuint64-tuniform-2m128-ciphertext.svg "${PATH_TO_DOC_ASSETS}" 2>/dev/null
|
||||
|
||||
- name: Get current date
|
||||
id: get-date
|
||||
|
||||
24
.github/workflows/generate_svg_common.yml
vendored
24
.github/workflows/generate_svg_common.yml
vendored
@@ -5,22 +5,20 @@ on:
|
||||
inputs:
|
||||
backend:
|
||||
type: string
|
||||
required: true
|
||||
hardware_name:
|
||||
type: string
|
||||
required: true
|
||||
layer:
|
||||
type: string
|
||||
required: true
|
||||
pbs_kind: # Valid values are 'classical', 'multi_bit' or 'any'
|
||||
type: string
|
||||
required: true
|
||||
grouping_factor: # Valid values are 2, 3, or 4
|
||||
type: string
|
||||
default: 4
|
||||
bench_type: # Valid values are 'latency', 'throughput'
|
||||
type: string
|
||||
required: true
|
||||
backend_comparison:
|
||||
type: boolean
|
||||
default: false
|
||||
time_span_days:
|
||||
type: string
|
||||
default: 60
|
||||
@@ -50,6 +48,7 @@ jobs:
|
||||
persist-credentials: 'false'
|
||||
|
||||
- name: Produce table from database
|
||||
if: inputs.backend_comparison == false
|
||||
run: |
|
||||
python3 -m pip install -r ci/data_extractor/requirements.txt
|
||||
python3 ci/data_extractor/src/data_extractor.py "${OUTPUT_FILENAME}" \
|
||||
@@ -76,6 +75,21 @@ jobs:
|
||||
DATA_EXTRACTOR_DATABASE_HOST: ${{ secrets.DATA_EXTRACTOR_DATABASE_HOST }}
|
||||
DATA_EXTRACTOR_DATABASE_PASSWORD: ${{ secrets.DATA_EXTRACTOR_DATABASE_PASSWORD }}
|
||||
|
||||
- name: Produce backends comparison table from database
|
||||
if: inputs.backend_comparison == true
|
||||
run: |
|
||||
python3 -m pip install -r ci/data_extractor/requirements.txt
|
||||
python3 ci/data_extractor/src/data_extractor.py "${OUTPUT_FILENAME}" \
|
||||
--generate-svg \
|
||||
--backend-comparison\
|
||||
--time-span-days "${TIME_SPAN}"
|
||||
env:
|
||||
OUTPUT_FILENAME: ${{ inputs.output_filename }}
|
||||
TIME_SPAN: ${{ inputs.time_span_days }}
|
||||
DATA_EXTRACTOR_DATABASE_USER: ${{ secrets.DATA_EXTRACTOR_DATABASE_USER }}
|
||||
DATA_EXTRACTOR_DATABASE_HOST: ${{ secrets.DATA_EXTRACTOR_DATABASE_HOST }}
|
||||
DATA_EXTRACTOR_DATABASE_PASSWORD: ${{ secrets.DATA_EXTRACTOR_DATABASE_PASSWORD }}
|
||||
|
||||
- name: Upload tables
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4
|
||||
with:
|
||||
|
||||
13
.github/workflows/generate_svgs.yml
vendored
13
.github/workflows/generate_svgs.yml
vendored
@@ -137,6 +137,19 @@ jobs:
|
||||
DATA_EXTRACTOR_DATABASE_HOST: ${{ secrets.DATA_EXTRACTOR_DATABASE_HOST }}
|
||||
DATA_EXTRACTOR_DATABASE_PASSWORD: ${{ secrets.DATA_EXTRACTOR_DATABASE_PASSWORD }}
|
||||
|
||||
backend-comparison-latency-table:
|
||||
name: generate_documentation_svgs/backend-comparison-latency-table
|
||||
uses: ./.github/workflows/generate_svg_common.yml
|
||||
if: inputs.generate-cpu-svgs && inputs.generate-gpu-svgs && inputs.generate-hpu-svgs
|
||||
with:
|
||||
backend_comparison: true
|
||||
time_span_days: ${{ inputs.time_span_days }}
|
||||
output_filename: cpu-gpu-hpu-integer-benchmark-fheuint64-tuniform-2m128-ciphertext
|
||||
secrets:
|
||||
DATA_EXTRACTOR_DATABASE_USER: ${{ secrets.DATA_EXTRACTOR_DATABASE_USER }}
|
||||
DATA_EXTRACTOR_DATABASE_HOST: ${{ secrets.DATA_EXTRACTOR_DATABASE_HOST }}
|
||||
DATA_EXTRACTOR_DATABASE_PASSWORD: ${{ secrets.DATA_EXTRACTOR_DATABASE_PASSWORD }}
|
||||
|
||||
# -----------------------------------------------------------
|
||||
# PBS benchmarks tables
|
||||
# -----------------------------------------------------------
|
||||
|
||||
@@ -57,15 +57,52 @@ class RustType(enum.Enum):
|
||||
|
||||
FheUint2 = 2
|
||||
FheUint4 = 4
|
||||
FheUint6 = 6
|
||||
FheUint8 = 8
|
||||
FheUint10 = 10
|
||||
FheUint12 = 12
|
||||
FheUint14 = 14
|
||||
FheUint16 = 16
|
||||
FheUint32 = 32
|
||||
FheUint64 = 64
|
||||
FheUint128 = 128
|
||||
FheUint256 = 256
|
||||
FheUint512 = 512
|
||||
|
||||
@staticmethod
|
||||
def from_int(value):
|
||||
match value:
|
||||
case 2:
|
||||
return RustType.FheUint2
|
||||
case 4:
|
||||
return RustType.FheUint4
|
||||
case 6:
|
||||
return RustType.FheUint6
|
||||
case 8:
|
||||
return RustType.FheUint8
|
||||
case 10:
|
||||
return RustType.FheUint10
|
||||
case 12:
|
||||
return RustType.FheUint12
|
||||
case 14:
|
||||
return RustType.FheUint14
|
||||
case 16:
|
||||
return RustType.FheUint16
|
||||
case 32:
|
||||
return RustType.FheUint32
|
||||
case 64:
|
||||
return RustType.FheUint64
|
||||
case 128:
|
||||
return RustType.FheUint128
|
||||
case 256:
|
||||
return RustType.FheUint256
|
||||
case 512:
|
||||
return RustType.FheUint512
|
||||
case _:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
ALL_RUST_TYPES = [
|
||||
ALL_RUST_INTEGER_TYPES = [
|
||||
RustType.FheUint2,
|
||||
RustType.FheUint4,
|
||||
RustType.FheUint8,
|
||||
@@ -492,9 +529,19 @@ class BenchDetails:
|
||||
self.operation_name = parts[2] if parts[1] == "cuda" else parts[1]
|
||||
case Layer.HLApi:
|
||||
if parts[1] in ["cuda", "hpu"]:
|
||||
self.operation_name = "::".join(parts[2:-1])
|
||||
if "_PARAM_" in parts[-2]:
|
||||
# Case for arithmetic operations (add, sub, mul,...)
|
||||
self.operation_name = "::".join(parts[2:-2])
|
||||
else:
|
||||
# Case for higher-level operation (erc20 transfer, dex,...)
|
||||
self.operation_name = "::".join(parts[2:-1])
|
||||
else:
|
||||
self.operation_name = "::".join(parts[1:-1])
|
||||
if "_PARAM_" in parts[-2]:
|
||||
# Case for arithmetic operations (add, sub, mul,...)
|
||||
self.operation_name = "::".join(parts[1:-2])
|
||||
else:
|
||||
# Case for higher-level operation (erc20 transfer, dex,...)
|
||||
self.operation_name = "::".join(parts[1:-1])
|
||||
self.rust_type = parts[-1].partition("_mean")[0]
|
||||
case Layer.Shortint:
|
||||
self.operation_name = parts[1]
|
||||
|
||||
109
ci/data_extractor/src/comparison.py
Normal file
109
ci/data_extractor/src/comparison.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import copy
|
||||
|
||||
import config
|
||||
import connector
|
||||
|
||||
from benchmark_specs import Backend, Layer, RustType, OperandType, PBSKind
|
||||
from formatter import GenericFormatter, OPERATION_SIZE_COLUMN_HEADER
|
||||
import utils
|
||||
|
||||
DEFAULT_CPU_HARDWARE = "hpc7a.96xlarge"
|
||||
DEFAULT_GPU_HARDWARE = "n3-H100-SXM5x8"
|
||||
DEFAULT_HPU_HARDWARE = "hpu_x1"
|
||||
|
||||
|
||||
def perform_backends_comparison(
|
||||
conn: connector.PostgreConnector, user_config: config.UserConfig
|
||||
):
|
||||
"""
|
||||
Compares benchmark data for different backends (CPU, GPU, HPU) using the provided
|
||||
database connection and user configurations. The function fetches, processes, and
|
||||
formats benchmark data for each backend, considering specific configurations and
|
||||
hardware capabilities. Finally, it combines the formatted results into a unified
|
||||
array for comparison.
|
||||
|
||||
:param conn: A database connector used to fetch benchmark data from the data source.
|
||||
:type conn: PostgreConnector
|
||||
:param user_config: A user configuration copied and updated for each backend data fetch.
|
||||
:type user_config: UserConfig
|
||||
:return: A list containing a single formatted `BenchArray`, merging benchmark data
|
||||
across all backends for comparison.
|
||||
:rtype: list[BenchArray]
|
||||
"""
|
||||
user_config.layer = Layer.Integer
|
||||
conversion_func = utils.convert_latency_value_to_readable_text
|
||||
|
||||
backend_arrays = []
|
||||
|
||||
for backend, hardware_name in [
|
||||
(Backend.CPU, DEFAULT_CPU_HARDWARE),
|
||||
(Backend.GPU, DEFAULT_GPU_HARDWARE),
|
||||
(Backend.HPU, DEFAULT_HPU_HARDWARE),
|
||||
]:
|
||||
case_config = copy.deepcopy(user_config)
|
||||
case_config.backend = backend
|
||||
case_config.hardware = hardware_name
|
||||
if backend == Backend.GPU:
|
||||
case_config.pbs_kind = PBSKind.MultiBit
|
||||
case_config.grouping_factor = 4
|
||||
|
||||
print(f"Getting {backend} data")
|
||||
|
||||
res = conn.fetch_benchmark_data(case_config)
|
||||
|
||||
generic_formatter = GenericFormatter(
|
||||
case_config.layer,
|
||||
case_config.backend,
|
||||
case_config.pbs_kind,
|
||||
case_config.grouping_factor,
|
||||
)
|
||||
formatted_results = generic_formatter.format_data(
|
||||
res,
|
||||
conversion_func,
|
||||
)
|
||||
|
||||
# Currently max/min operations are not available at the integer layer for HPU backend.
|
||||
# Retrieve values by fetching HLAPI layer and insert them into the existing integer array.
|
||||
if backend == Backend.HPU:
|
||||
case_config.layer = Layer.HLApi
|
||||
hlapi_res = conn.fetch_benchmark_data(case_config)
|
||||
hlapi_generic_formatter = GenericFormatter(
|
||||
case_config.layer,
|
||||
case_config.backend,
|
||||
case_config.pbs_kind,
|
||||
case_config.grouping_factor,
|
||||
)
|
||||
hlapi_formatted_results = hlapi_generic_formatter.format_data(
|
||||
hlapi_res,
|
||||
conversion_func,
|
||||
)
|
||||
integer_sizes_fetched = formatted_results["max"].keys()
|
||||
formatted_results["unsigned_max"] = {
|
||||
k: v
|
||||
for k, v in hlapi_formatted_results["max"].items()
|
||||
if k in integer_sizes_fetched
|
||||
}
|
||||
|
||||
generic_arrays = generic_formatter.generate_array(
|
||||
formatted_results,
|
||||
OperandType.CipherText,
|
||||
included_types=[
|
||||
RustType.FheUint64,
|
||||
],
|
||||
)
|
||||
|
||||
resulting_array = generic_arrays[0]
|
||||
resulting_array.replace_column_name(
|
||||
RustType.FheUint64.name, case_config.backend.name
|
||||
)
|
||||
backend_arrays.append(resulting_array)
|
||||
|
||||
print(f"Generating comparison array")
|
||||
|
||||
backend_arrays[0].extend(
|
||||
*backend_arrays[1:], ops_column_name=OPERATION_SIZE_COLUMN_HEADER
|
||||
)
|
||||
|
||||
return [
|
||||
backend_arrays[0],
|
||||
]
|
||||
@@ -18,8 +18,15 @@ import argparse
|
||||
import datetime
|
||||
import formatter
|
||||
import sys
|
||||
from formatter import CSVFormatter, GenericFormatter, MarkdownFormatter, SVGFormatter
|
||||
from formatter import (
|
||||
CSVFormatter,
|
||||
GenericFormatter,
|
||||
MarkdownFormatter,
|
||||
SVGFormatter,
|
||||
BenchArray,
|
||||
)
|
||||
|
||||
import comparison
|
||||
import config
|
||||
import connector
|
||||
import regression
|
||||
@@ -90,6 +97,12 @@ parser.add_argument(
|
||||
default="cpu",
|
||||
help="Backend on which benchmarks have run",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backends-comparison",
|
||||
dest="backends_comparison",
|
||||
action="store_true",
|
||||
help="Produce a comparison between backends on 64 bits ciphertext/ciphertext integer operations",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tfhe-rs-layer",
|
||||
dest="layer",
|
||||
@@ -265,8 +278,6 @@ def perform_data_extraction(
|
||||
layer: Layer,
|
||||
operand_type: OperandType,
|
||||
output_filename: str,
|
||||
generate_markdown: bool = False,
|
||||
generate_svg: bool = False,
|
||||
):
|
||||
"""
|
||||
Extracts, formats, and processes benchmark data for a specified operand type and
|
||||
@@ -285,13 +296,9 @@ def perform_data_extraction(
|
||||
:param output_filename: The base filename for the output files where results
|
||||
will be saved.
|
||||
:type output_filename: str
|
||||
:param generate_markdown: Boolean flag indicating whether to generate an
|
||||
output file in Markdown (.md) format.
|
||||
:type generate_markdown: bool
|
||||
:param generate_svg: Boolean flag indicating whether to generate an output
|
||||
file in SVG (.svg) format.
|
||||
:type generate_svg: bool
|
||||
:return: None
|
||||
|
||||
:return: Generic formatted arrays
|
||||
:rtype: list[BenchArray]
|
||||
"""
|
||||
try:
|
||||
res = conn.fetch_benchmark_data(user_config, operand_type)
|
||||
@@ -332,6 +339,18 @@ def perform_data_extraction(
|
||||
excluded_types=[RustType.FheUint2, RustType.FheUint4, RustType.FheUint256],
|
||||
)
|
||||
|
||||
return generic_arrays
|
||||
|
||||
|
||||
def generate_files_from_arrays(
|
||||
generic_arrays: list[BenchArray],
|
||||
user_config: config.UserConfig,
|
||||
layer: Layer,
|
||||
output_filename: str,
|
||||
file_suffix: str = "",
|
||||
generate_markdown: bool = False,
|
||||
generate_svg: bool = False,
|
||||
):
|
||||
for array in generic_arrays:
|
||||
metadata_suffix = ""
|
||||
if array.metadata:
|
||||
@@ -397,6 +416,23 @@ if __name__ == "__main__":
|
||||
else:
|
||||
sys.exit(0)
|
||||
|
||||
if args.backends_comparison:
|
||||
try:
|
||||
arrays = comparison.perform_backends_comparison(conn, user_config)
|
||||
generate_files_from_arrays(
|
||||
arrays,
|
||||
user_config,
|
||||
layer,
|
||||
user_config.output_file,
|
||||
generate_markdown=args.generate_markdown,
|
||||
generate_svg=args.generate_svg,
|
||||
)
|
||||
except RuntimeError as err:
|
||||
print(f"Failed to perform backends comparison: {err}")
|
||||
sys.exit(2)
|
||||
else:
|
||||
sys.exit(0)
|
||||
|
||||
hardware_list = (
|
||||
args.hardware_comp.lower().split(",") if args.hardware_comp else None
|
||||
)
|
||||
@@ -412,11 +448,16 @@ if __name__ == "__main__":
|
||||
if layer == Layer.CoreCrypto and operand_type == OperandType.PlainText:
|
||||
continue
|
||||
|
||||
perform_data_extraction(
|
||||
file_suffix = f"_{operand_type.lower()}"
|
||||
arrays = perform_data_extraction(
|
||||
user_config, layer, operand_type, user_config.output_file, file_suffix
|
||||
)
|
||||
generate_files_from_arrays(
|
||||
arrays,
|
||||
user_config,
|
||||
layer,
|
||||
operand_type,
|
||||
user_config.output_file,
|
||||
file_suffix=file_suffix,
|
||||
generate_markdown=args.generate_markdown,
|
||||
generate_svg=args.generate_svg,
|
||||
)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import collections
|
||||
import copy
|
||||
import enum
|
||||
import pathlib
|
||||
import xml.dom.minidom
|
||||
@@ -7,7 +6,7 @@ from collections.abc import Callable
|
||||
|
||||
import svg
|
||||
from benchmark_specs import (
|
||||
ALL_RUST_TYPES,
|
||||
ALL_RUST_INTEGER_TYPES,
|
||||
Backend,
|
||||
BenchDetails,
|
||||
CoreCryptoOperation,
|
||||
@@ -48,6 +47,10 @@ def compute_comparisons(*results):
|
||||
return gains
|
||||
|
||||
|
||||
OPERATION_SIZE_COLUMN_HEADER = "Operation \\ Size"
|
||||
OPERATION_PRECISION_COLUMN_HEADER = "Operation \\ Precision (bits)"
|
||||
|
||||
|
||||
class OperationDisplayName(enum.StrEnum):
|
||||
Negation = "Negation (-)"
|
||||
AddSub = "Add / Sub (+,-)"
|
||||
@@ -67,7 +70,18 @@ class OperationDisplayName(enum.StrEnum):
|
||||
|
||||
|
||||
class BenchArray:
|
||||
def __init__(self, array, layer, metadata: dict = None):
|
||||
"""
|
||||
Represents a structured collection of benchmark data encapsulated with metadata.
|
||||
|
||||
:ivar array: The primary dataset stored as a list of dictionaries.
|
||||
:type array: list[dict]
|
||||
:ivar layer: The associated layer information for this dataset.
|
||||
:type layer: Layer
|
||||
:ivar metadata: Additional metadata associated with the dataset.
|
||||
:type metadata: dict, optional
|
||||
"""
|
||||
|
||||
def __init__(self, array: list[dict], layer: Layer, metadata: dict = None):
|
||||
self.array = array
|
||||
self.layer = layer
|
||||
self.metadata = metadata
|
||||
@@ -75,6 +89,56 @@ class BenchArray:
|
||||
def __repr__(self):
|
||||
return f"BenchArray(layer={self.layer}, metadata={self.metadata})"
|
||||
|
||||
def replace_column_name(self, current: str, new: str):
|
||||
"""
|
||||
Replaces the name of a column for the whole array.
|
||||
If the ``current`` column name does not exist, the array is left unchanged.
|
||||
|
||||
:param current: The column name to be replaced.
|
||||
:type current: str
|
||||
:param new: The new column name to replace the current one.
|
||||
:type new: str
|
||||
:return: None
|
||||
"""
|
||||
for line in self.array:
|
||||
try:
|
||||
line[new] = line.pop(current)
|
||||
except KeyError:
|
||||
# Column name doesn't exist on this line, ignoring
|
||||
continue
|
||||
|
||||
def extend(self, *others, ops_column_name: str = None):
|
||||
"""
|
||||
Extends the current array with values from other benchmark arrays by combining
|
||||
and updating the entries based on a specified column name. This method merges
|
||||
items from the current array and other provided arrays by using the values
|
||||
from the specified column as keys.
|
||||
|
||||
:param others: Additional benchmark arrays to merge into the current array.
|
||||
Each `other` must have a similar structure as the current array.
|
||||
:type others: iterable[BenchArray]
|
||||
:param ops_column_name: The name of the column whose values will be used as
|
||||
keys for merging arrays. This parameter is optional, but required for
|
||||
the merge operation to function correctly.
|
||||
:type ops_column_name: str
|
||||
:return: None
|
||||
"""
|
||||
array_as_dict = {}
|
||||
for item in self.array:
|
||||
op_name = item.pop(ops_column_name)
|
||||
array_as_dict[op_name] = item
|
||||
|
||||
for other_bench_array in others:
|
||||
for item in other_bench_array.array:
|
||||
op_name = item.pop(ops_column_name)
|
||||
array_as_dict[op_name].update(item)
|
||||
|
||||
array_as_list = []
|
||||
for op_name, values in array_as_dict.items():
|
||||
array_as_list.append({ops_column_name: op_name, **values})
|
||||
|
||||
self.array = array_as_list
|
||||
|
||||
|
||||
class GenericFormatter:
|
||||
def __init__(
|
||||
@@ -142,6 +206,8 @@ class GenericFormatter:
|
||||
return self._format_integer_data(data, conversion_func)
|
||||
case Layer.CoreCrypto:
|
||||
return self._format_core_crypto_data(data, conversion_func)
|
||||
case Layer.HLApi:
|
||||
return self._format_hlapi_data(data, conversion_func)
|
||||
case _:
|
||||
raise NotImplementedError(f"layer '{self.layer}' not supported yet")
|
||||
|
||||
@@ -199,10 +265,35 @@ class GenericFormatter:
|
||||
|
||||
return formatted
|
||||
|
||||
@staticmethod
|
||||
def _format_hlapi_data(data: dict[BenchDetails : list[int]], conversion_func):
|
||||
formatted = collections.defaultdict(
|
||||
lambda: {
|
||||
2: "N/A",
|
||||
4: "N/A",
|
||||
8: "N/A",
|
||||
10: "N/A",
|
||||
12: "N/A",
|
||||
14: "N/A",
|
||||
16: "N/A",
|
||||
32: "N/A",
|
||||
64: "N/A",
|
||||
128: "N/A",
|
||||
}
|
||||
)
|
||||
for details, timings in data.items():
|
||||
test_name = details.operation_name.lstrip("ops::")
|
||||
bit_width = details.bit_size
|
||||
value = conversion_func(timings[-1])
|
||||
formatted[test_name][bit_width] = value
|
||||
|
||||
return formatted
|
||||
|
||||
def generate_array(
|
||||
self,
|
||||
data,
|
||||
operand_type: OperandType = None,
|
||||
included_types: list[RustType] = ALL_RUST_INTEGER_TYPES,
|
||||
excluded_types: list[RustType] = None,
|
||||
) -> list[BenchArray]:
|
||||
"""
|
||||
@@ -218,7 +309,11 @@ class GenericFormatter:
|
||||
:param operand_type: Specifies the type of operand to guide the array generation.
|
||||
Defaults to `None`.
|
||||
:type operand_type: OperandType, optional
|
||||
:param included_types: A list of `RustType` to include in array generation.
|
||||
Defaults to `benchmark_specs.ALL_RUST_INTEGER_TYPES`.
|
||||
:type included_types: list[RustType], optional
|
||||
:param excluded_types: A list of `RustType` to exclude from array generation.
|
||||
Note that any type available in excluded_types takes precedence over the same type in included_types.
|
||||
Defaults to `None`.
|
||||
:type excluded_types: list[RustType], optional
|
||||
|
||||
@@ -230,7 +325,7 @@ class GenericFormatter:
|
||||
match self.layer:
|
||||
case Layer.Integer:
|
||||
return self._generate_unsigned_integer_array(
|
||||
data, operand_type, excluded_types
|
||||
data, operand_type, included_types, excluded_types
|
||||
)
|
||||
case Layer.CoreCrypto:
|
||||
return self._generate_core_crypto_showcase_arrays(data)
|
||||
@@ -241,6 +336,7 @@ class GenericFormatter:
|
||||
self,
|
||||
data,
|
||||
operand_type: OperandType = None,
|
||||
included_types: list[RustType] = ALL_RUST_INTEGER_TYPES,
|
||||
excluded_types: list[RustType] = None,
|
||||
):
|
||||
match operand_type:
|
||||
@@ -284,7 +380,7 @@ class GenericFormatter:
|
||||
]
|
||||
case Backend.HPU:
|
||||
operations = [
|
||||
f"{prefix}_neg",
|
||||
f"{prefix}_sub", # Negation operation doesn't exist in HPU yet
|
||||
(
|
||||
f"{prefix}_add"
|
||||
if operand_type == OperandType.CipherText
|
||||
@@ -339,12 +435,12 @@ class GenericFormatter:
|
||||
OperationDisplayName.Select,
|
||||
]
|
||||
|
||||
types = ALL_RUST_TYPES.copy()
|
||||
types = included_types.copy()
|
||||
excluded_types = excluded_types if excluded_types is not None else []
|
||||
for excluded in excluded_types:
|
||||
types.remove(excluded)
|
||||
|
||||
first_column_header = "Operation \\ Size"
|
||||
first_column_header = OPERATION_SIZE_COLUMN_HEADER
|
||||
|
||||
# Adapt list to plaintext benchmarks results.
|
||||
if operand_type == OperandType.PlainText and self.backend != Backend.HPU:
|
||||
@@ -373,14 +469,17 @@ class GenericFormatter:
|
||||
operations.pop(0)
|
||||
display_names.pop(0)
|
||||
|
||||
data_without_excluded_types = copy.deepcopy(data)
|
||||
for v in data_without_excluded_types.values():
|
||||
for excluded in excluded_types:
|
||||
try:
|
||||
v.pop(excluded.value)
|
||||
except KeyError:
|
||||
# Type is not contained in the results, ignoring
|
||||
continue
|
||||
data_without_excluded_types = {}
|
||||
for op, values in data.items():
|
||||
try:
|
||||
data_without_excluded_types[op] = {
|
||||
typ: val
|
||||
for typ, val in values.items()
|
||||
if RustType.from_int(typ) in types
|
||||
}
|
||||
except NotImplementedError:
|
||||
# Unknown type from database, ignoring
|
||||
continue
|
||||
|
||||
filtered_data = filter(lambda t: t in operations, data_without_excluded_types)
|
||||
# Get operation names as key of the dict to ease fetching
|
||||
@@ -493,7 +592,7 @@ class GenericFormatter:
|
||||
# Operation is not supposed to appear in the formatted array.
|
||||
continue
|
||||
|
||||
first_column_header = "Operation \\ Precision (bits)"
|
||||
first_column_header = OPERATION_PRECISION_COLUMN_HEADER
|
||||
|
||||
arrays = []
|
||||
for key, results in sorted_results.items():
|
||||
@@ -705,27 +804,38 @@ class SVGFormatter(GenericFormatter):
|
||||
|
||||
match layer:
|
||||
case Layer.Integer:
|
||||
type_name_width = type_ident.strip("FheUint")
|
||||
header_elements.extend(
|
||||
[
|
||||
# Rust type class
|
||||
if type_ident.startswith("FheUint"):
|
||||
type_name_width = type_ident.strip("FheUint")
|
||||
header_elements.extend(
|
||||
[
|
||||
# Rust type class
|
||||
self._build_svg_text(
|
||||
curr_x + per_timing_col_width / 2,
|
||||
row_height / 3,
|
||||
"FheUint",
|
||||
fill=WHITE_COLOR,
|
||||
font_weight="bold",
|
||||
),
|
||||
# Actual size of the Rust type
|
||||
self._build_svg_text(
|
||||
curr_x + per_timing_col_width / 2,
|
||||
2 * row_height / 3 + 3,
|
||||
type_name_width,
|
||||
fill=WHITE_COLOR,
|
||||
font_weight="bold",
|
||||
),
|
||||
]
|
||||
)
|
||||
else: # Backends comparison (CPU, GPU, HPU)
|
||||
header_elements.append(
|
||||
self._build_svg_text(
|
||||
curr_x + per_timing_col_width / 2,
|
||||
row_height / 3,
|
||||
"FheUint",
|
||||
row_height / 2,
|
||||
type_ident,
|
||||
fill=WHITE_COLOR,
|
||||
font_weight="bold",
|
||||
),
|
||||
# Actual size of the Rust type
|
||||
self._build_svg_text(
|
||||
curr_x + per_timing_col_width / 2,
|
||||
2 * row_height / 3 + 3,
|
||||
type_name_width,
|
||||
fill=WHITE_COLOR,
|
||||
font_weight="bold",
|
||||
),
|
||||
]
|
||||
)
|
||||
)
|
||||
)
|
||||
case Layer.CoreCrypto:
|
||||
header_elements.append(
|
||||
# Core_crypto arrays contains only ciphertext modulus size as headers
|
||||
|
||||
Reference in New Issue
Block a user