mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
WIP: implement tables cases - half done
This commit is contained in:
@@ -50,6 +50,14 @@ class Layer(enum.StrEnum):
|
||||
raise NotImplementedError(f"layer '{layer_name}' not supported")
|
||||
|
||||
|
||||
class OperandSize(int):
|
||||
"""
|
||||
Syntactic sugar for operand sizes handled as integer.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class RustType(enum.Enum):
|
||||
"""
|
||||
Represents different integer Rust types used in tfhe-rs.
|
||||
@@ -189,6 +197,26 @@ class OperandType(enum.StrEnum):
|
||||
PlainText = "PlainText"
|
||||
|
||||
|
||||
class AtomicPattern(enum.StrEnum):
|
||||
KSPBS = "KS_PBS"
|
||||
PBSKS = "PBS_KS"
|
||||
KS32PBS = "KS32_PBS"
|
||||
|
||||
@staticmethod
|
||||
def from_str(pattern_name):
|
||||
match pattern_name.lower():
|
||||
case "ks_pbs":
|
||||
return AtomicPattern.KSPBS
|
||||
case "pbs_ks":
|
||||
return AtomicPattern.PBSKS
|
||||
case "ks32_pbs":
|
||||
return AtomicPattern.KS32PBS
|
||||
case _:
|
||||
raise NotImplementedError(
|
||||
f"atomic pattern '{pattern_name}' not supported"
|
||||
)
|
||||
|
||||
|
||||
class PBSKind(enum.StrEnum):
|
||||
"""
|
||||
Represents the kind of parameter set used for Programmable Bootstrapping operation.
|
||||
@@ -279,6 +307,71 @@ class ErrorFailureProbability(enum.IntEnum):
|
||||
return self.to_str()
|
||||
|
||||
|
||||
class GroupingFactor(enum.IntEnum):
|
||||
Two = 2
|
||||
Three = 3
|
||||
Four = 4
|
||||
|
||||
@staticmethod
|
||||
def from_str(gf_value):
|
||||
try:
|
||||
int_value = int(gf_value)
|
||||
except ValueError:
|
||||
raise ValueError(f"grouping factor '{gf_value}' is not an integer")
|
||||
|
||||
match int_value:
|
||||
case 2:
|
||||
return GroupingFactor.Two
|
||||
case 3:
|
||||
return GroupingFactor.Three
|
||||
case 4:
|
||||
return GroupingFactor.Four
|
||||
case _:
|
||||
raise NotImplementedError(f"grouping factor '{gf_value}' not supported")
|
||||
|
||||
|
||||
class Precision(enum.Enum):
|
||||
M1C1 = (1, 1)
|
||||
M2C2 = (2, 2)
|
||||
M3C3 = (3, 3)
|
||||
M4C4 = (4, 4)
|
||||
|
||||
@staticmethod
|
||||
def from_param_name(name: str) -> enum.Enum:
|
||||
parts = name.split("_")
|
||||
message = None
|
||||
carry = None
|
||||
|
||||
for i, part in enumerate(parts):
|
||||
if part == "MESSAGE":
|
||||
message = int(parts[i + 1])
|
||||
elif part == "CARRY":
|
||||
carry = int(parts[i + 1])
|
||||
|
||||
if message is None or carry is None:
|
||||
raise ValueError(f"could not extract precision in '{name}'")
|
||||
|
||||
match (message, carry):
|
||||
case (1, 1):
|
||||
return Precision.M1C1
|
||||
case (2, 2):
|
||||
return Precision.M2C2
|
||||
case (3, 3):
|
||||
return Precision.M3C3
|
||||
case (4, 4):
|
||||
return Precision.M4C4
|
||||
case _:
|
||||
raise NotImplementedError(
|
||||
f"precision message={message} carry={carry} not supported yet"
|
||||
)
|
||||
|
||||
def message(self) -> int:
|
||||
return self.value[0]
|
||||
|
||||
def carry(self) -> int:
|
||||
return self.value[1]
|
||||
|
||||
|
||||
class BenchType(enum.Enum):
|
||||
Latency = 0
|
||||
Throughput = 1
|
||||
@@ -309,8 +402,7 @@ class ParamsDefinition:
|
||||
"""
|
||||
|
||||
def __init__(self, param_name: str):
|
||||
self.message_size = None
|
||||
self.carry_size = None
|
||||
self.precision = None
|
||||
self.pbs_kind = None
|
||||
self.grouping_factor = None
|
||||
self.noise_distribution = None
|
||||
@@ -323,8 +415,8 @@ class ParamsDefinition:
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
self.message_size == other.message_size
|
||||
and self.carry_size == other.carry_size
|
||||
self.precision.message() == other.precision.message()
|
||||
and self.precision.carry() == other.precision.carry()
|
||||
and self.pbs_kind == other.pbs_kind
|
||||
and self.grouping_factor == other.grouping_factor
|
||||
and self.noise_distribution == other.noise_distribution
|
||||
@@ -337,16 +429,16 @@ class ParamsDefinition:
|
||||
def __lt__(self, other):
|
||||
|
||||
return (
|
||||
self.message_size < other.message_size
|
||||
and self.carry_size < other.carry_size
|
||||
self.precision.message() < other.precision.message()
|
||||
and self.precision.carry() < other.precision.carry()
|
||||
and self.p_fail < other.p_fail
|
||||
)
|
||||
|
||||
def __hash__(self):
|
||||
return hash(
|
||||
(
|
||||
self.message_size,
|
||||
self.carry_size,
|
||||
self.precision.message,
|
||||
self.precision.carry,
|
||||
self.pbs_kind,
|
||||
self.grouping_factor,
|
||||
self.noise_distribution,
|
||||
@@ -357,7 +449,47 @@ class ParamsDefinition:
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"ParamsDefinition(message_size={self.message_size}, carry_size={self.carry_size}, pbs_kind={self.pbs_kind}, grouping_factor={self.grouping_factor}, noise_distribution={self.noise_distribution}, atomic_pattern={self.atomic_pattern}, p_fail={self.p_fail}, version={self.version}, details={self.details})"
|
||||
return (
|
||||
f"ParamsDefinition("
|
||||
f"message_size={self.precision.message()}, "
|
||||
f"carry_size={self.precision.carry()}, "
|
||||
# f"pbs_kind={self.pbs_kind}, "
|
||||
f"grouping_factor={self.grouping_factor}, "
|
||||
# f"noise_distribution={self.noise_distribution}, "
|
||||
f"atomic_pattern={self.atomic_pattern}, "
|
||||
f"p_fail={self.p_fail}, "
|
||||
# f"version={self.version}, "
|
||||
# f"details={self.details})"
|
||||
)
|
||||
|
||||
def components_match(self, *components):
|
||||
for component in components:
|
||||
match component:
|
||||
case Precision():
|
||||
if self.precision != component:
|
||||
return False
|
||||
case PBSKind():
|
||||
if self.pbs_kind != component:
|
||||
return False
|
||||
case GroupingFactor():
|
||||
if self.grouping_factor != component:
|
||||
return False
|
||||
case NoiseDistribution():
|
||||
if self.noise_distribution != component:
|
||||
return False
|
||||
case AtomicPattern():
|
||||
if self.atomic_pattern != component:
|
||||
return False
|
||||
case ErrorFailureProbability():
|
||||
if self.p_fail != component:
|
||||
return False
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"unsupported component type for matching '{type(component)}'"
|
||||
)
|
||||
|
||||
# Each component matches
|
||||
return True
|
||||
|
||||
def _parse_param_name(self, param_name: str) -> None:
|
||||
split_params = param_name.split("_")
|
||||
@@ -376,6 +508,12 @@ class ParamsDefinition:
|
||||
|
||||
params_variation_parts.append(part)
|
||||
|
||||
try:
|
||||
self.precision = Precision.from_param_name(param_name)
|
||||
except ValueError or NotImplementedError:
|
||||
# Might be a Boolean parameters set
|
||||
raise ParametersFormatNotSupported(param_name)
|
||||
|
||||
try:
|
||||
self.p_fail = ErrorFailureProbability.from_param_name(param_name)
|
||||
pfail_index = split_params.index(self.p_fail.to_str())
|
||||
@@ -395,23 +533,17 @@ class ParamsDefinition:
|
||||
noise_distribution_index = None
|
||||
|
||||
try:
|
||||
self.message_size = int(split_params[split_params.index("MESSAGE") + 1])
|
||||
carry_size_index = split_params.index("CARRY") + 1
|
||||
self.carry_size = int(split_params[carry_size_index])
|
||||
self.atomic_pattern = "_".join(
|
||||
split_params[carry_size_index + 1 : noise_distribution_index]
|
||||
)
|
||||
except ValueError:
|
||||
# Might be a Boolean parameters set
|
||||
raise ParametersFormatNotSupported(param_name)
|
||||
|
||||
try:
|
||||
if noise_distribution_index:
|
||||
self.atomic_pattern = "_".join(
|
||||
split_params[carry_size_index + 1 : noise_distribution_index]
|
||||
self.atomic_pattern = AtomicPattern.from_str(
|
||||
"_".join(
|
||||
split_params[carry_size_index + 1 : noise_distribution_index]
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.atomic_pattern = "_".join(split_params[carry_size_index + 1 :])
|
||||
self.atomic_pattern = AtomicPattern.from_str(
|
||||
"_".join(split_params[carry_size_index + 1 :])
|
||||
)
|
||||
except ValueError:
|
||||
# Might be a Boolean parameters set
|
||||
raise ParametersFormatNotSupported(param_name)
|
||||
@@ -424,7 +556,9 @@ class ParamsDefinition:
|
||||
|
||||
try:
|
||||
# This is a multi-bit parameters set
|
||||
self.grouping_factor = int(split_params[split_params.index("GROUP") + 1])
|
||||
self.grouping_factor = GroupingFactor.from_str(
|
||||
split_params[split_params.index("GROUP") + 1]
|
||||
)
|
||||
self.pbs_kind = PBSKind.MultiBit
|
||||
except ValueError:
|
||||
# This is a classical parameters set
|
||||
@@ -448,7 +582,7 @@ class BenchDetails:
|
||||
:type bit_size: int
|
||||
"""
|
||||
|
||||
def __init__(self, layer: Layer, bench_full_name: str, bit_size: int):
|
||||
def __init__(self, layer: Layer, bench_full_name: str, bit_size: OperandSize):
|
||||
self.layer = layer
|
||||
|
||||
self.operation_name = None
|
||||
|
||||
@@ -9,6 +9,7 @@ from benchmark_specs import (
|
||||
BenchDetails,
|
||||
BenchType,
|
||||
Layer,
|
||||
OperandSize,
|
||||
OperandType,
|
||||
PBSKind,
|
||||
)
|
||||
@@ -302,7 +303,7 @@ class PostgreConnector:
|
||||
raise NoDataFound(msg)
|
||||
|
||||
for line in lines:
|
||||
bit_width = line[1]
|
||||
bit_width = OperandSize(line[1])
|
||||
|
||||
bench_details = BenchDetails(layer, line[0], bit_width)
|
||||
value = line[-1] if last_value_only else line[-3]
|
||||
|
||||
@@ -25,9 +25,9 @@ import formatters.core
|
||||
import formatters.hlapi
|
||||
import formatters.integer
|
||||
import regression
|
||||
import whitepaper
|
||||
from benchmark_specs import BenchType, Layer, OperandType, RustType
|
||||
from formatters.common import BenchArray, CSVFormatter, MarkdownFormatter, SVGFormatter
|
||||
from whitepaper import whitepaper
|
||||
|
||||
import utils
|
||||
|
||||
@@ -442,7 +442,7 @@ if __name__ == "__main__":
|
||||
if args.generate_whitepaper_latex:
|
||||
try:
|
||||
whitepaper.perform_latex_generation(conn, user_config)
|
||||
except RuntimeError as err:
|
||||
except Exception as err:
|
||||
print(f"Failed to perform whitepaper LaTex tables generation: {err}")
|
||||
sys.exit(2)
|
||||
else:
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import collections
|
||||
import enum
|
||||
import pathlib
|
||||
import xml.dom.minidom
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import svg
|
||||
from benchmark_specs import (
|
||||
@@ -12,6 +14,7 @@ from benchmark_specs import (
|
||||
ErrorFailureProbability,
|
||||
Layer,
|
||||
NoiseDistribution,
|
||||
OperandSize,
|
||||
OperandType,
|
||||
PBSKind,
|
||||
RustType,
|
||||
@@ -189,6 +192,16 @@ class GenericFormatter:
|
||||
# Must be implemented by subclasses
|
||||
raise NotImplementedError
|
||||
|
||||
def format_data_with_available_sizes(
|
||||
self, data: dict[BenchDetails : list[int]], conversion_func
|
||||
):
|
||||
return self._format_data_with_available_sizes(data, conversion_func)
|
||||
|
||||
@staticmethod
|
||||
def _format_data_with_available_sizes(*args, **kwargs):
|
||||
# Must be implemented by subclasses
|
||||
raise NotImplementedError
|
||||
|
||||
def generate_array(
|
||||
self,
|
||||
data,
|
||||
@@ -659,11 +672,170 @@ class SVGFormatter(GenericFormatter):
|
||||
return self.generate_svg_table(array)
|
||||
|
||||
|
||||
class LatexColumn:
|
||||
pass
|
||||
# TODO faire une sous classe de certaines enum avec juste une méthode to_latex_str() pour avoir la valeur qui correspond a celle affichée dans le whitepaper
|
||||
|
||||
class LatexRow:
|
||||
pass
|
||||
|
||||
class ElementType(enum.Enum):
|
||||
Operation = 0
|
||||
ParamComponent = 1
|
||||
SizeComponent = 2
|
||||
|
||||
|
||||
class LatexSeparator(enum.StrEnum):
|
||||
BottomRule = r"\bottomrule"
|
||||
MidRule = r"\midrule"
|
||||
HLine = r"\hline"
|
||||
|
||||
|
||||
class LatexElement:
|
||||
def __init__(
|
||||
self,
|
||||
elem: Any,
|
||||
elem_type: ElementType,
|
||||
latex_str: str,
|
||||
display_element: bool = True,
|
||||
):
|
||||
self.elem = elem
|
||||
self.type = elem_type
|
||||
self.latex_str = latex_str
|
||||
self.display_element = display_element
|
||||
|
||||
def format(self):
|
||||
return self.latex_str
|
||||
|
||||
|
||||
class LatexRowElement(LatexElement):
|
||||
def __init__(
|
||||
self, *args, display_as_column: bool = False, column_span: int = None, **kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.display_as_column = display_as_column
|
||||
self.column_span = column_span or 1
|
||||
|
||||
def format(self):
|
||||
if self.display_as_column:
|
||||
return "\n".join(
|
||||
[
|
||||
f"{LatexSeparator.MidRule}",
|
||||
f"\\multicolumn{{{self.column_span}}}{{c}}{{\\textbf{{{self.latex_str}}}}} \\\\",
|
||||
f"{LatexSeparator.MidRule}",
|
||||
]
|
||||
)
|
||||
else:
|
||||
return self.latex_str
|
||||
|
||||
|
||||
class LatexColumnElement(LatexElement):
|
||||
def __init__(self, *args, sub_cols=None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.sub_columns: list[LatexColumnElement] = sub_cols or []
|
||||
|
||||
|
||||
class LatexArraySection:
|
||||
def __init__(self, rows: list):
|
||||
self.rows = rows
|
||||
# TODO intégrer ici la notion de mise en gras du/des résultats les plus petits (enum ?)
|
||||
# Ca peut être le minimum sur une ligne (ou un partie de ligne, cf multi-bit core)
|
||||
# ou bien le minimum sur une section entière (cf classic core)
|
||||
|
||||
def format(self, bench_data):
|
||||
lines = []
|
||||
|
||||
# IMPORTANT: data must be sorted with row discriminant first and then column discriminant that would give access to the value
|
||||
for row in self.rows:
|
||||
try:
|
||||
lines.append(self._format_row(row, bench_data))
|
||||
except TypeError:
|
||||
# Row is a single LatexRowElement
|
||||
lines.append(row.format())
|
||||
|
||||
print("Lines:", lines) # DEBUG
|
||||
return lines
|
||||
|
||||
@staticmethod
|
||||
def _format_row(row, bench_data):
|
||||
operation = None
|
||||
row_filters = []
|
||||
row_values = []
|
||||
|
||||
for elem in row:
|
||||
if isinstance(elem, LatexRowElement):
|
||||
match elem.type:
|
||||
case ElementType.Operation:
|
||||
operation = elem.elem
|
||||
case ElementType.ParamComponent:
|
||||
row_filters.append(elem.elem)
|
||||
|
||||
if elem.display_element:
|
||||
row_values.append(elem.format())
|
||||
elif isinstance(elem, LatexColumnElement):
|
||||
bench_values = bench_data.get(operation)
|
||||
if elem.sub_columns:
|
||||
for sub_col in elem.sub_columns:
|
||||
for param_def, value in bench_values.items():
|
||||
if isinstance(sub_col.elem, OperandSize):
|
||||
bit_size = sub_col.elem
|
||||
if param_def.components_match(*row_filters, elem.elem):
|
||||
try:
|
||||
row_values.append(value[bit_size])
|
||||
except KeyError:
|
||||
# TODO maybe returning a "N/A" value would be better than raising an error
|
||||
raise KeyError(
|
||||
f"bit size '{bit_size}' not found in bench data (params: {param_def}, values: {value})"
|
||||
)
|
||||
else:
|
||||
if param_def.components_match(
|
||||
*row_filters, elem.elem, sub_col.elem
|
||||
):
|
||||
# Only one value is stored in the operand size dict.
|
||||
_, v = value.popitem()
|
||||
row_values.append(v)
|
||||
else:
|
||||
for param_def, value in bench_values.items():
|
||||
if isinstance(
|
||||
value, dict
|
||||
): # TODO If(condition) et else(content) à supprimer
|
||||
# Handling data from the integer tfhe-rs layer.
|
||||
bit_size = elem.elem
|
||||
if param_def.components_match(*row_filters):
|
||||
try:
|
||||
row_values.append(value[bit_size])
|
||||
except KeyError:
|
||||
# TODO maybe returning a "N/A" value would be better than raising an error
|
||||
raise KeyError(
|
||||
f"bit size '{bit_size}' not found in bench data (params: {param_def}, values: {value})"
|
||||
)
|
||||
else:
|
||||
# Handling data from the core-crypto tfhe-rs layer.
|
||||
if param_def.components_match(*row_filters, elem.elem):
|
||||
row_values.append(value)
|
||||
|
||||
print("Row values:", row_values)
|
||||
return " & ".join(row_values) + r" \\"
|
||||
|
||||
|
||||
class LatexTable:
|
||||
def __init__(
|
||||
self,
|
||||
array_section: list[LatexArraySection],
|
||||
):
|
||||
self.array_sections = array_section
|
||||
|
||||
@staticmethod
|
||||
def get_separator(sep: LatexSeparator):
|
||||
return sep.value
|
||||
|
||||
def format_table(self, bench_data):
|
||||
parts = []
|
||||
|
||||
for section in self.array_sections:
|
||||
parts.extend(section.format(bench_data))
|
||||
parts.append(self.get_separator(LatexSeparator.HLine))
|
||||
|
||||
# Replace the last section separator with a bottom rule
|
||||
parts[-1] = self.get_separator(LatexSeparator.BottomRule)
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
class LatexFormatter(GenericFormatter):
|
||||
|
||||
@@ -88,6 +88,22 @@ class CoreFormatter(GenericFormatter):
|
||||
|
||||
return formatted
|
||||
|
||||
@staticmethod
|
||||
def _format_data_with_available_sizes(
|
||||
data: dict[BenchDetails : list[int]], conversion_func
|
||||
):
|
||||
formatted = collections.defaultdict(lambda: collections.defaultdict(lambda: {}))
|
||||
|
||||
for details, timings in data.items():
|
||||
reduced_params = details.get_params_definition()
|
||||
test_name = details.operation_name
|
||||
bit_width = details.bit_size
|
||||
value = conversion_func(timings[-1])
|
||||
|
||||
formatted[test_name][reduced_params][bit_width] = value
|
||||
|
||||
return formatted
|
||||
|
||||
def _generate_arrays(
|
||||
self,
|
||||
data,
|
||||
|
||||
@@ -56,6 +56,22 @@ class IntegerFormatter(GenericFormatter):
|
||||
|
||||
return formatted
|
||||
|
||||
@staticmethod
|
||||
def _format_data_with_available_sizes(
|
||||
data: dict[BenchDetails : list[int]], conversion_func
|
||||
):
|
||||
formatted = collections.defaultdict(lambda: collections.defaultdict(lambda: {}))
|
||||
|
||||
for details, timings in data.items():
|
||||
reduced_params = details.get_params_definition()
|
||||
test_name = "_".join((details.sign_flavor.value, details.operation_name))
|
||||
bit_width = details.bit_size
|
||||
value = conversion_func(timings[-1])
|
||||
|
||||
formatted[test_name][reduced_params][bit_width] = value
|
||||
|
||||
return formatted
|
||||
|
||||
def _generate_arrays(
|
||||
self,
|
||||
data,
|
||||
|
||||
@@ -1,185 +0,0 @@
|
||||
import copy
|
||||
import itertools
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
import config
|
||||
import connector
|
||||
from benchmark_specs import ErrorFailureProbability
|
||||
from formatters.core import CoreFormatter
|
||||
import utils
|
||||
|
||||
class Default(dict):
|
||||
def __missing__(self, key):
|
||||
return f"{{{key}}}"
|
||||
|
||||
|
||||
class ParametersFilterCase:
|
||||
def __init__(
|
||||
self,
|
||||
param_name_pattern: str,
|
||||
pfails: list[ErrorFailureProbability] = None,
|
||||
grouping_factors: list[int] = None,
|
||||
message_carry_sizes: list[int] = None,
|
||||
):
|
||||
self.param_name_pattern = param_name_pattern
|
||||
self.pfails = pfails or []
|
||||
self.grouping_factors = grouping_factors or []
|
||||
self.message_carry_sizes = message_carry_sizes or []
|
||||
|
||||
def get_parameter_variants(self):
|
||||
after_pfails = []
|
||||
for pfail in self.pfails:
|
||||
after_pfails.append(
|
||||
self.param_name_pattern.format_map(Default(pfail=pfail.to_str()))
|
||||
)
|
||||
|
||||
after_grouping_factors = []
|
||||
if after_pfails:
|
||||
for name in after_pfails:
|
||||
for gf in self.grouping_factors:
|
||||
after_grouping_factors.append(name.format_map(Default(gf=gf)))
|
||||
else:
|
||||
for gf in self.grouping_factors:
|
||||
after_grouping_factors.append(
|
||||
self.grouping_factors.format_map(Default(gf=gf))
|
||||
)
|
||||
|
||||
after_msg_carry_sizes = []
|
||||
if after_grouping_factors:
|
||||
for name in after_grouping_factors:
|
||||
for size in self.message_carry_sizes:
|
||||
after_msg_carry_sizes.append(
|
||||
name.format_map(Default(msg=size, carry=size))
|
||||
)
|
||||
else:
|
||||
for size in self.message_carry_sizes:
|
||||
after_msg_carry_sizes.append(
|
||||
self.param_name_pattern.format_map(Default(msg=size, carry=size))
|
||||
)
|
||||
|
||||
return (
|
||||
after_msg_carry_sizes
|
||||
or after_grouping_factors
|
||||
or after_pfails
|
||||
or [self.param_name_pattern]
|
||||
)
|
||||
|
||||
|
||||
CORE_CRYPTO_PARAM_CASES = [
|
||||
ParametersFilterCase(
|
||||
"%PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_{pfail}",
|
||||
pfails=[
|
||||
ErrorFailureProbability.TWO_MINUS_64,
|
||||
ErrorFailureProbability.TWO_MINUS_128,
|
||||
],
|
||||
),
|
||||
ParametersFilterCase(
|
||||
"%PARAM_MULTI_BIT_GROUP_{gf}_MESSAGE_{msg}_CARRY_{carry}_KS_PBS_GAUSSIAN_{pfail}",
|
||||
pfails=[
|
||||
ErrorFailureProbability.TWO_MINUS_64,
|
||||
ErrorFailureProbability.TWO_MINUS_128,
|
||||
],
|
||||
grouping_factors=[2, 3, 4],
|
||||
message_carry_sizes=[1, 2, 3, 4],
|
||||
),
|
||||
]
|
||||
|
||||
INTEGER_PARAM_CASES = [
|
||||
ParametersFilterCase(
|
||||
"%PARAM_MESSAGE_{msg}_CARRY_{carry}_KS_PBS_GAUSSIAN_{pfail}", # 1_1, 2_2, 4_4 (pfail: 2m64, 2m128)
|
||||
pfails=[
|
||||
ErrorFailureProbability.TWO_MINUS_64,
|
||||
ErrorFailureProbability.TWO_MINUS_128,
|
||||
],
|
||||
message_carry_sizes=[1, 2, 4],
|
||||
),
|
||||
ParametersFilterCase(
|
||||
"%PARAM_MESSAGE_2_CARRY_2_KS32_PBS_GAUSSIAN_{pfail}",
|
||||
pfails=[
|
||||
ErrorFailureProbability.TWO_MINUS_64,
|
||||
ErrorFailureProbability.TWO_MINUS_128,
|
||||
],
|
||||
),
|
||||
ParametersFilterCase(
|
||||
"%PARAM_MULTI_BIT_GROUP_{gf}_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_{pfail}",
|
||||
pfails=[
|
||||
ErrorFailureProbability.TWO_MINUS_64,
|
||||
ErrorFailureProbability.TWO_MINUS_128,
|
||||
],
|
||||
grouping_factors=[2, 3, 4],
|
||||
),
|
||||
ParametersFilterCase(
|
||||
"%COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_{pfail}",
|
||||
pfails=[
|
||||
ErrorFailureProbability.TWO_MINUS_64,
|
||||
ErrorFailureProbability.TWO_MINUS_128,
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
INTEGER_SPECIAL_CASE_OPERATIONS_FILTER = [
|
||||
"add_parallelized",
|
||||
"mul_parallelized",
|
||||
"bitand_parallelized",
|
||||
]
|
||||
|
||||
def _generate_latex_tables(
|
||||
conn: connector.PostgreConnector,
|
||||
user_config: config.UserConfig,
|
||||
result_dir: pathlib.Path,
|
||||
):
|
||||
conversion_func = utils.convert_latency_value_to_readable_text
|
||||
|
||||
case_config = copy.deepcopy(user_config)
|
||||
case_config.backend = config.Backend.CPU
|
||||
case_config.pbs_kind = config.PBSKind.Any
|
||||
|
||||
for case in CORE_CRYPTO_PARAM_CASES:
|
||||
case_config.layer = config.Layer.CoreCrypto
|
||||
param_patterns = case.get_parameter_variants()
|
||||
res = conn.fetch_benchmark_data(case_config, param_name_patterns=param_patterns)
|
||||
|
||||
generic_formatter = CoreFormatter(
|
||||
case_config.layer,
|
||||
case_config.backend,
|
||||
case_config.pbs_kind,
|
||||
case_config.grouping_factor,
|
||||
)
|
||||
formatted_results = generic_formatter.format_data(
|
||||
res,
|
||||
conversion_func,
|
||||
)
|
||||
|
||||
# TODO créer le tbaleau qui va bien en fonction du cas
|
||||
for r in formatted_results.items(): # DEBUG
|
||||
print(r)
|
||||
|
||||
print("--------------------------------------------------")
|
||||
print("--------------------------------------------------")
|
||||
print("--------------------------------------------------")
|
||||
|
||||
# TODO prendre la valeur minimum dans un groupe d'opération (ex: min/max, gt/ge/lt/le, ...)
|
||||
|
||||
|
||||
def perform_latex_generation(
|
||||
conn: connector.PostgreConnector, user_config: config.UserConfig
|
||||
):
|
||||
dir_path = user_config.output_file
|
||||
try:
|
||||
dir_path.mkdir(parents=True, exist_ok=True)
|
||||
except Exception as e: ## TODO find the exact exception that can be raised here
|
||||
pass
|
||||
|
||||
# TODO passer un dossier en user_config et enregistrer les tables latex dans ce dossier un fichier = une table
|
||||
_generate_latex_tables(conn, user_config, dir_path)
|
||||
|
||||
|
||||
# Table 2 is core crypto KS/KS-PBS/PBS for 2_2_128 and 2_2_64
|
||||
# Table 3 is core crypto KS/KS-PBS/PBS in the multibit case for V1_4_PARAM_MULTI_BIT_GROUP_X_MESSAGE_Y_CARRY_Y_KS_PBS_GAUSSIAN_2MZ for X in [2,3,4], Y in [1,2,3,4] and Z in [64, 128]
|
||||
# Table 5 is special case with 1_1_64, 2_2_64, 4_4_64
|
||||
# Table 6 is special case with 1_1_128, 2_2_128, 4_4_128
|
||||
# Table 7 is the multibit special case with V1_4_PARAM_MULTI_BIT_GROUP_X_MESSAGE_Y_CARRY_Y_KS_PBS_GAUSSIAN_2MZ for X in [2,3,4], Y = 2, Z = 64.
|
||||
# Table 8 is special case for 2_2_128 and 2_2_128 KS_32 <--- changed
|
||||
# Table 9, 10, 11, 12 are the integer ops for 2_2_64_KS32 and 2_2_128_KS32
|
||||
# Table 13 is compression (which I believe was in special case?) for the 64/128 compression parameters
|
||||
1
ci/data_extractor/src/whitepaper/__init__.py
Normal file
1
ci/data_extractor/src/whitepaper/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from . import whitepaper
|
||||
627
ci/data_extractor/src/whitepaper/tables.py
Normal file
627
ci/data_extractor/src/whitepaper/tables.py
Normal file
@@ -0,0 +1,627 @@
|
||||
"""
|
||||
All tables are named after their label in the whitepaper.
|
||||
"""
|
||||
|
||||
from benchmark_specs import (
|
||||
AtomicPattern,
|
||||
ErrorFailureProbability,
|
||||
GroupingFactor,
|
||||
OperandSize,
|
||||
Precision,
|
||||
)
|
||||
from formatters.common import (
|
||||
ElementType,
|
||||
LatexArraySection,
|
||||
LatexColumnElement,
|
||||
LatexRowElement,
|
||||
LatexTable,
|
||||
)
|
||||
|
||||
# Table 2 is core crypto KS/KS-PBS/PBS for 2_2_128 and 2_2_64
|
||||
# Table 3 is core crypto KS/KS-PBS/PBS in the multibit case for V1_4_PARAM_MULTI_BIT_GROUP_X_MESSAGE_Y_CARRY_Y_KS_PBS_GAUSSIAN_2MZ for X in [2,3,4], Y in [1,2,3,4] and Z in [64, 128]
|
||||
# Table 5 is special case with 1_1_64, 2_2_64, 4_4_64
|
||||
# Table 6 is special case with 1_1_128, 2_2_128, 4_4_128
|
||||
# Table 7 is the multibit special case with V1_4_PARAM_MULTI_BIT_GROUP_X_MESSAGE_Y_CARRY_Y_KS_PBS_GAUSSIAN_2MZ for X in [2,3,4], Y = 2, Z = 64.
|
||||
# Table 8 is special case for 2_2_128 and 2_2_128 KS_32 <--- changed
|
||||
# Table 9, 10, 11, 12 are the integer ops for 2_2_64_KS32 and 2_2_128_KS32
|
||||
# Table 13 is compression (which I believe was in special case?) for the 64/128 compression parameters
|
||||
|
||||
# ------------------
|
||||
# LaTex row elements
|
||||
# ------------------
|
||||
|
||||
|
||||
def _get_operation_elem(
|
||||
operation_name: str,
|
||||
latex_str: str,
|
||||
display_element: bool = True,
|
||||
display_as_column: bool = False,
|
||||
column_span: int = 1,
|
||||
) -> LatexRowElement:
|
||||
return LatexRowElement(
|
||||
operation_name,
|
||||
ElementType.Operation,
|
||||
latex_str,
|
||||
display_element=display_element,
|
||||
display_as_column=display_as_column,
|
||||
column_span=column_span,
|
||||
)
|
||||
|
||||
|
||||
PFAIL_2M64_ELEM = LatexRowElement(
|
||||
ErrorFailureProbability.TWO_MINUS_64, ElementType.ParamComponent, r"\(2^{-64}\)"
|
||||
)
|
||||
PFAIL_2M128_ELEM = LatexRowElement(
|
||||
ErrorFailureProbability.TWO_MINUS_128, ElementType.ParamComponent, r"\(2^{-128}\)"
|
||||
)
|
||||
|
||||
PFAIL_2M64_HIDDEN_ELEM = LatexRowElement(
|
||||
ErrorFailureProbability.TWO_MINUS_64,
|
||||
ElementType.ParamComponent,
|
||||
r"\(2^{-64}\)",
|
||||
display_element=False,
|
||||
)
|
||||
PFAIL_2M128_HIDDEN_ELEM = LatexRowElement(
|
||||
ErrorFailureProbability.TWO_MINUS_128,
|
||||
ElementType.ParamComponent,
|
||||
r"\(2^{-128}\)",
|
||||
display_element=False,
|
||||
)
|
||||
|
||||
M1C1_PFAIL_64_ELEM = LatexRowElement(
|
||||
Precision.M1C1,
|
||||
ElementType.ParamComponent,
|
||||
r"${\tt 1\_1\_64}$",
|
||||
)
|
||||
M2C2_PFAIL_64_ELEM = LatexRowElement(
|
||||
Precision.M2C2,
|
||||
ElementType.ParamComponent,
|
||||
r"${\tt 2\_2\_64}$",
|
||||
)
|
||||
M4C4_PFAIL_64_ELEM = LatexRowElement(
|
||||
Precision.M4C4,
|
||||
ElementType.ParamComponent,
|
||||
r"${\tt 4\_4\_64}$",
|
||||
)
|
||||
M1C1_PFAIL_128_ELEM = LatexRowElement(
|
||||
Precision.M1C1,
|
||||
ElementType.ParamComponent,
|
||||
r"${\tt 1\_1\_128}$",
|
||||
)
|
||||
M2C2_PFAIL_128_ELEM = LatexRowElement(
|
||||
Precision.M2C2,
|
||||
ElementType.ParamComponent,
|
||||
r"${\tt 2\_2\_128}$",
|
||||
)
|
||||
M2C2_PFAIL_128_KS32_ELEM = LatexRowElement(
|
||||
Precision.M2C2,
|
||||
ElementType.ParamComponent,
|
||||
r"${\tt 2\_2\_128\_KS32}$",
|
||||
)
|
||||
M4C4_PFAIL_128_ELEM = LatexRowElement(
|
||||
Precision.M4C4,
|
||||
ElementType.ParamComponent,
|
||||
r"${\tt 4\_4\_128}$",
|
||||
)
|
||||
|
||||
KSPBS_HIDDEN_ELEM = LatexRowElement(
|
||||
AtomicPattern.KSPBS,
|
||||
ElementType.ParamComponent,
|
||||
"",
|
||||
display_element=False,
|
||||
)
|
||||
KS32PBS_HIDDEN_ELEM = LatexRowElement(
|
||||
AtomicPattern.KS32PBS,
|
||||
ElementType.ParamComponent,
|
||||
"",
|
||||
display_element=False,
|
||||
)
|
||||
|
||||
KS_OP_ELEM = _get_operation_elem("keyswitch", r"\ks")
|
||||
PBS_OP_ELEM = _get_operation_elem("pbs_mem_optimized", r"\pbs")
|
||||
MB_PBS_OP_ELEM = _get_operation_elem("multi_bit_deterministic_pbs", r"\mbpbs")
|
||||
KSPBS_OP_ELEM = _get_operation_elem("ks_pbs", r"\kspbs")
|
||||
KS_MB_PBS_OP_ELEM = _get_operation_elem("multi_bit_deterministic_ks_pbs", r"\ksmbpbs")
|
||||
|
||||
ADD_OP_HIDDEN_ELEM = _get_operation_elem(
|
||||
"unsigned_add_parallelized", "Addition", display_element=False
|
||||
)
|
||||
BITAND_OP_HIDDEN_ELEM = _get_operation_elem(
|
||||
"unsigned_bitand_parallelized", "Bitwise AND", display_element=False
|
||||
)
|
||||
MUL_OP_HIDDEN_ELEM = _get_operation_elem(
|
||||
"unsigned_mul_parallelized", "Multiplication", display_element=False
|
||||
)
|
||||
|
||||
# ---------------------
|
||||
# LaTex column elements
|
||||
# ---------------------
|
||||
|
||||
|
||||
def _get_precision_column_element(
|
||||
precision: Precision, sub_cols: list[LatexColumnElement] = None
|
||||
) -> LatexColumnElement:
|
||||
return LatexColumnElement(
|
||||
precision, ElementType.ParamComponent, "", sub_cols=sub_cols
|
||||
)
|
||||
|
||||
|
||||
def _get_grouping_factor_column_element(
|
||||
grouping_factor: GroupingFactor,
|
||||
) -> LatexColumnElement:
|
||||
return LatexColumnElement(
|
||||
grouping_factor, ElementType.ParamComponent, str(grouping_factor)
|
||||
)
|
||||
|
||||
|
||||
def _get_operand_size_column_element(
|
||||
op_size: int,
|
||||
) -> LatexColumnElement:
|
||||
return LatexColumnElement(
|
||||
OperandSize(op_size), ElementType.SizeComponent, str(op_size)
|
||||
)
|
||||
|
||||
|
||||
# Operand size is set to the value of the message size since it's stored as is in the database.
|
||||
M1C1_COL_ELEM = _get_precision_column_element(
|
||||
Precision.M1C1, sub_cols=[_get_operand_size_column_element(1)]
|
||||
)
|
||||
M2C2_COL_ELEM = _get_precision_column_element(
|
||||
Precision.M2C2, sub_cols=[_get_operand_size_column_element(2)]
|
||||
)
|
||||
M3C3_COL_ELEM = _get_precision_column_element(
|
||||
Precision.M3C3, sub_cols=[_get_operand_size_column_element(3)]
|
||||
)
|
||||
M4C4_COL_ELEM = _get_precision_column_element(
|
||||
Precision.M4C4, sub_cols=[_get_operand_size_column_element(4)]
|
||||
)
|
||||
|
||||
ALL_GROUPING_FACTORS_ELEM = [
|
||||
_get_grouping_factor_column_element(GroupingFactor.Two),
|
||||
_get_grouping_factor_column_element(GroupingFactor.Three),
|
||||
_get_grouping_factor_column_element(GroupingFactor.Four),
|
||||
]
|
||||
|
||||
M1C1_ALL_GF_COL_ELEM = _get_precision_column_element(
|
||||
Precision.M1C1, sub_cols=ALL_GROUPING_FACTORS_ELEM
|
||||
)
|
||||
M2C2_ALL_GF_COL_ELEM = _get_precision_column_element(
|
||||
Precision.M2C2, sub_cols=ALL_GROUPING_FACTORS_ELEM
|
||||
)
|
||||
M3C3_ALL_GF_COL_ELEM = _get_precision_column_element(
|
||||
Precision.M3C3, sub_cols=ALL_GROUPING_FACTORS_ELEM
|
||||
)
|
||||
M4C4_ALL_GF_COL_ELEM = _get_precision_column_element(
|
||||
Precision.M4C4, sub_cols=ALL_GROUPING_FACTORS_ELEM
|
||||
)
|
||||
|
||||
ALL_OPERAND_SIZES_ELEM = [
|
||||
_get_operand_size_column_element(4),
|
||||
_get_operand_size_column_element(8),
|
||||
_get_operand_size_column_element(16),
|
||||
_get_operand_size_column_element(32),
|
||||
_get_operand_size_column_element(64),
|
||||
_get_operand_size_column_element(128),
|
||||
_get_operand_size_column_element(256),
|
||||
]
|
||||
|
||||
OP_AS_COL_SPAN = len(ALL_OPERAND_SIZES_ELEM) + 1
|
||||
|
||||
# Operations used in integer special case tables
|
||||
ADD_OP_AS_COL_ELEM = _get_operation_elem(
|
||||
"unsigned_add_parallelized",
|
||||
"Addition",
|
||||
display_as_column=True,
|
||||
column_span=OP_AS_COL_SPAN,
|
||||
)
|
||||
BITAND_OP_AS_COL_ELEM = _get_operation_elem(
|
||||
"unsigned_bitand_parallelized",
|
||||
"Bitwise AND",
|
||||
display_as_column=True,
|
||||
column_span=OP_AS_COL_SPAN,
|
||||
)
|
||||
MUL_OP_AS_COL_ELEM = _get_operation_elem(
|
||||
"unsigned_mul_parallelized",
|
||||
"Multiplication",
|
||||
display_as_column=True,
|
||||
column_span=OP_AS_COL_SPAN,
|
||||
)
|
||||
|
||||
|
||||
# TODO On a besoin de garder quelque part la raw_value du bench (après conversion str) pour effectuer des comparaisons et trouver le minimum sur un groupe de résultats
|
||||
# TODO Calculer les lignes "amortized" pour TABLE_2
|
||||
|
||||
# -----------------------
|
||||
# LaTex table definitions
|
||||
# -----------------------
|
||||
|
||||
TABLE_PBS_BENCH = LatexTable(
|
||||
[
|
||||
LatexArraySection(
|
||||
[
|
||||
[
|
||||
PFAIL_2M64_ELEM,
|
||||
KS_OP_ELEM,
|
||||
M1C1_COL_ELEM,
|
||||
M2C2_COL_ELEM,
|
||||
M3C3_COL_ELEM,
|
||||
M4C4_COL_ELEM,
|
||||
],
|
||||
[
|
||||
PFAIL_2M64_ELEM,
|
||||
PBS_OP_ELEM,
|
||||
M1C1_COL_ELEM,
|
||||
M2C2_COL_ELEM,
|
||||
M3C3_COL_ELEM,
|
||||
M4C4_COL_ELEM,
|
||||
],
|
||||
[
|
||||
PFAIL_2M64_ELEM,
|
||||
KSPBS_OP_ELEM,
|
||||
M1C1_COL_ELEM,
|
||||
M2C2_COL_ELEM,
|
||||
M3C3_COL_ELEM,
|
||||
M4C4_COL_ELEM,
|
||||
],
|
||||
# [PFAIL_2M64_ELEM, KSPBS_OP_ELEM_AMORTIZED, M1C1_COL_ELEM, M2C2_COL_ELEM, M3C3_COL_ELEM, M4C4_COL_ELEM], # TODO line de données à calculer depuis les résultats
|
||||
]
|
||||
),
|
||||
LatexArraySection(
|
||||
[
|
||||
[
|
||||
PFAIL_2M128_ELEM,
|
||||
KS_OP_ELEM,
|
||||
M1C1_COL_ELEM,
|
||||
M2C2_COL_ELEM,
|
||||
M3C3_COL_ELEM,
|
||||
M4C4_COL_ELEM,
|
||||
],
|
||||
[
|
||||
PFAIL_2M128_ELEM,
|
||||
PBS_OP_ELEM,
|
||||
M1C1_COL_ELEM,
|
||||
M2C2_COL_ELEM,
|
||||
M3C3_COL_ELEM,
|
||||
M4C4_COL_ELEM,
|
||||
],
|
||||
[
|
||||
PFAIL_2M128_ELEM,
|
||||
KSPBS_OP_ELEM,
|
||||
M1C1_COL_ELEM,
|
||||
M2C2_COL_ELEM,
|
||||
M3C3_COL_ELEM,
|
||||
M4C4_COL_ELEM,
|
||||
],
|
||||
# [PFAIL_2M128_ELEM, KSPBS_OP_ELEM_AMORTIZED, M1C1_COL_ELEM, M2C2_COL_ELEM, Precision.M3C3, M4C4_COL_ELEM],
|
||||
]
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
TABLE_BENCH_MULTIBIT_BY_PRECISION = LatexTable(
|
||||
[
|
||||
LatexArraySection(
|
||||
[
|
||||
[
|
||||
PFAIL_2M64_ELEM,
|
||||
KS_OP_ELEM,
|
||||
M1C1_ALL_GF_COL_ELEM,
|
||||
M2C2_ALL_GF_COL_ELEM,
|
||||
M3C3_ALL_GF_COL_ELEM,
|
||||
M4C4_ALL_GF_COL_ELEM,
|
||||
],
|
||||
[
|
||||
PFAIL_2M64_ELEM,
|
||||
MB_PBS_OP_ELEM,
|
||||
M1C1_ALL_GF_COL_ELEM,
|
||||
M2C2_ALL_GF_COL_ELEM,
|
||||
M3C3_ALL_GF_COL_ELEM,
|
||||
M4C4_ALL_GF_COL_ELEM,
|
||||
],
|
||||
[
|
||||
PFAIL_2M64_ELEM,
|
||||
KS_MB_PBS_OP_ELEM,
|
||||
M1C1_ALL_GF_COL_ELEM,
|
||||
M2C2_ALL_GF_COL_ELEM,
|
||||
M3C3_ALL_GF_COL_ELEM,
|
||||
M4C4_ALL_GF_COL_ELEM,
|
||||
],
|
||||
]
|
||||
),
|
||||
LatexArraySection(
|
||||
[
|
||||
[
|
||||
PFAIL_2M128_ELEM,
|
||||
KS_OP_ELEM,
|
||||
M1C1_ALL_GF_COL_ELEM,
|
||||
M2C2_ALL_GF_COL_ELEM,
|
||||
M3C3_ALL_GF_COL_ELEM,
|
||||
M4C4_ALL_GF_COL_ELEM,
|
||||
],
|
||||
[
|
||||
PFAIL_2M128_ELEM,
|
||||
MB_PBS_OP_ELEM,
|
||||
M1C1_ALL_GF_COL_ELEM,
|
||||
M2C2_ALL_GF_COL_ELEM,
|
||||
M3C3_ALL_GF_COL_ELEM,
|
||||
M4C4_ALL_GF_COL_ELEM,
|
||||
],
|
||||
[
|
||||
PFAIL_2M128_ELEM,
|
||||
KS_MB_PBS_OP_ELEM,
|
||||
M1C1_ALL_GF_COL_ELEM,
|
||||
M2C2_ALL_GF_COL_ELEM,
|
||||
M3C3_ALL_GF_COL_ELEM,
|
||||
M4C4_ALL_GF_COL_ELEM,
|
||||
],
|
||||
]
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
TABLE_COMPARISON_OPERATIONS_PRECISION_PFAIL64 = LatexTable(
|
||||
[
|
||||
LatexArraySection(
|
||||
[
|
||||
BITAND_OP_AS_COL_ELEM,
|
||||
[
|
||||
BITAND_OP_HIDDEN_ELEM,
|
||||
PFAIL_2M64_HIDDEN_ELEM,
|
||||
M1C1_PFAIL_64_ELEM,
|
||||
*ALL_OPERAND_SIZES_ELEM,
|
||||
],
|
||||
[
|
||||
BITAND_OP_HIDDEN_ELEM,
|
||||
PFAIL_2M64_HIDDEN_ELEM,
|
||||
M2C2_PFAIL_64_ELEM,
|
||||
*ALL_OPERAND_SIZES_ELEM,
|
||||
],
|
||||
[
|
||||
BITAND_OP_HIDDEN_ELEM,
|
||||
PFAIL_2M64_HIDDEN_ELEM,
|
||||
M4C4_PFAIL_64_ELEM,
|
||||
*ALL_OPERAND_SIZES_ELEM,
|
||||
],
|
||||
ADD_OP_AS_COL_ELEM,
|
||||
[
|
||||
ADD_OP_HIDDEN_ELEM,
|
||||
PFAIL_2M64_HIDDEN_ELEM,
|
||||
M1C1_PFAIL_64_ELEM,
|
||||
*ALL_OPERAND_SIZES_ELEM,
|
||||
],
|
||||
[
|
||||
ADD_OP_HIDDEN_ELEM,
|
||||
PFAIL_2M64_HIDDEN_ELEM,
|
||||
M2C2_PFAIL_64_ELEM,
|
||||
*ALL_OPERAND_SIZES_ELEM,
|
||||
],
|
||||
[
|
||||
ADD_OP_HIDDEN_ELEM,
|
||||
PFAIL_2M64_HIDDEN_ELEM,
|
||||
M4C4_PFAIL_64_ELEM,
|
||||
*ALL_OPERAND_SIZES_ELEM,
|
||||
],
|
||||
MUL_OP_AS_COL_ELEM,
|
||||
[
|
||||
MUL_OP_HIDDEN_ELEM,
|
||||
PFAIL_2M64_HIDDEN_ELEM,
|
||||
M1C1_PFAIL_64_ELEM,
|
||||
*ALL_OPERAND_SIZES_ELEM,
|
||||
],
|
||||
[
|
||||
MUL_OP_HIDDEN_ELEM,
|
||||
PFAIL_2M64_HIDDEN_ELEM,
|
||||
M2C2_PFAIL_64_ELEM,
|
||||
*ALL_OPERAND_SIZES_ELEM,
|
||||
],
|
||||
[
|
||||
MUL_OP_HIDDEN_ELEM,
|
||||
PFAIL_2M64_HIDDEN_ELEM,
|
||||
M4C4_PFAIL_64_ELEM,
|
||||
*ALL_OPERAND_SIZES_ELEM,
|
||||
],
|
||||
]
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
TABLE_COMPARISON_OPERATIONS_PRECISION_PFAIL128 = LatexTable(
|
||||
[
|
||||
LatexArraySection(
|
||||
[
|
||||
BITAND_OP_AS_COL_ELEM,
|
||||
[
|
||||
BITAND_OP_HIDDEN_ELEM,
|
||||
PFAIL_2M128_HIDDEN_ELEM,
|
||||
M1C1_PFAIL_128_ELEM,
|
||||
*ALL_OPERAND_SIZES_ELEM,
|
||||
],
|
||||
[
|
||||
BITAND_OP_HIDDEN_ELEM,
|
||||
PFAIL_2M64_HIDDEN_ELEM,
|
||||
M2C2_PFAIL_64_ELEM,
|
||||
*ALL_OPERAND_SIZES_ELEM,
|
||||
],
|
||||
[
|
||||
BITAND_OP_HIDDEN_ELEM,
|
||||
PFAIL_2M64_HIDDEN_ELEM,
|
||||
M4C4_PFAIL_64_ELEM,
|
||||
*ALL_OPERAND_SIZES_ELEM,
|
||||
],
|
||||
ADD_OP_AS_COL_ELEM,
|
||||
[
|
||||
ADD_OP_HIDDEN_ELEM,
|
||||
PFAIL_2M128_HIDDEN_ELEM,
|
||||
M1C1_PFAIL_128_ELEM,
|
||||
*ALL_OPERAND_SIZES_ELEM,
|
||||
],
|
||||
[
|
||||
ADD_OP_HIDDEN_ELEM,
|
||||
PFAIL_2M64_HIDDEN_ELEM,
|
||||
M2C2_PFAIL_128_ELEM,
|
||||
*ALL_OPERAND_SIZES_ELEM,
|
||||
],
|
||||
[
|
||||
ADD_OP_HIDDEN_ELEM,
|
||||
PFAIL_2M64_HIDDEN_ELEM,
|
||||
M4C4_PFAIL_128_ELEM,
|
||||
*ALL_OPERAND_SIZES_ELEM,
|
||||
],
|
||||
MUL_OP_AS_COL_ELEM,
|
||||
[
|
||||
MUL_OP_HIDDEN_ELEM,
|
||||
PFAIL_2M128_HIDDEN_ELEM,
|
||||
M1C1_PFAIL_128_ELEM,
|
||||
*ALL_OPERAND_SIZES_ELEM,
|
||||
],
|
||||
[
|
||||
MUL_OP_HIDDEN_ELEM,
|
||||
PFAIL_2M64_HIDDEN_ELEM,
|
||||
M2C2_PFAIL_128_ELEM,
|
||||
*ALL_OPERAND_SIZES_ELEM,
|
||||
],
|
||||
[
|
||||
MUL_OP_HIDDEN_ELEM,
|
||||
PFAIL_2M64_HIDDEN_ELEM,
|
||||
M4C4_PFAIL_128_ELEM,
|
||||
*ALL_OPERAND_SIZES_ELEM,
|
||||
],
|
||||
]
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
TABLE_COMPARISON_OPERATIONS_BOOTSTRAPPING = LatexTable(
|
||||
[
|
||||
LatexArraySection(
|
||||
[
|
||||
BITAND_OP_AS_COL_ELEM,
|
||||
[
|
||||
BITAND_OP_HIDDEN_ELEM,
|
||||
PFAIL_2M128_HIDDEN_ELEM,
|
||||
M1C1_PFAIL_128_ELEM,
|
||||
*ALL_OPERAND_SIZES_ELEM,
|
||||
],
|
||||
[
|
||||
BITAND_OP_HIDDEN_ELEM,
|
||||
PFAIL_2M128_HIDDEN_ELEM,
|
||||
M2C2_PFAIL_128_ELEM,
|
||||
*ALL_OPERAND_SIZES_ELEM,
|
||||
],
|
||||
[
|
||||
BITAND_OP_HIDDEN_ELEM,
|
||||
PFAIL_2M128_HIDDEN_ELEM,
|
||||
M4C4_PFAIL_128_ELEM,
|
||||
*ALL_OPERAND_SIZES_ELEM,
|
||||
],
|
||||
ADD_OP_AS_COL_ELEM,
|
||||
[
|
||||
ADD_OP_HIDDEN_ELEM,
|
||||
PFAIL_2M128_HIDDEN_ELEM,
|
||||
M1C1_PFAIL_128_ELEM,
|
||||
*ALL_OPERAND_SIZES_ELEM,
|
||||
],
|
||||
[
|
||||
ADD_OP_HIDDEN_ELEM,
|
||||
PFAIL_2M128_HIDDEN_ELEM,
|
||||
M2C2_PFAIL_128_ELEM,
|
||||
*ALL_OPERAND_SIZES_ELEM,
|
||||
],
|
||||
[
|
||||
ADD_OP_HIDDEN_ELEM,
|
||||
PFAIL_2M128_HIDDEN_ELEM,
|
||||
M4C4_PFAIL_128_ELEM,
|
||||
*ALL_OPERAND_SIZES_ELEM,
|
||||
],
|
||||
MUL_OP_AS_COL_ELEM,
|
||||
[
|
||||
MUL_OP_HIDDEN_ELEM,
|
||||
PFAIL_2M128_HIDDEN_ELEM,
|
||||
M1C1_PFAIL_128_ELEM,
|
||||
*ALL_OPERAND_SIZES_ELEM,
|
||||
],
|
||||
[
|
||||
MUL_OP_HIDDEN_ELEM,
|
||||
PFAIL_2M128_HIDDEN_ELEM,
|
||||
M2C2_PFAIL_128_ELEM,
|
||||
*ALL_OPERAND_SIZES_ELEM,
|
||||
],
|
||||
[
|
||||
MUL_OP_HIDDEN_ELEM,
|
||||
PFAIL_2M128_HIDDEN_ELEM,
|
||||
M4C4_PFAIL_128_ELEM,
|
||||
*ALL_OPERAND_SIZES_ELEM,
|
||||
],
|
||||
]
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
TABLE_COMPARISON_OPERATIONS_BOOTSTRAPPING128KS32 = LatexTable(
|
||||
[
|
||||
LatexArraySection(
|
||||
[
|
||||
BITAND_OP_AS_COL_ELEM,
|
||||
[
|
||||
BITAND_OP_HIDDEN_ELEM,
|
||||
PFAIL_2M128_HIDDEN_ELEM,
|
||||
KSPBS_HIDDEN_ELEM,
|
||||
M2C2_PFAIL_128_ELEM,
|
||||
*ALL_OPERAND_SIZES_ELEM,
|
||||
],
|
||||
[
|
||||
BITAND_OP_HIDDEN_ELEM,
|
||||
PFAIL_2M128_HIDDEN_ELEM,
|
||||
KS32PBS_HIDDEN_ELEM,
|
||||
M2C2_PFAIL_128_KS32_ELEM,
|
||||
*ALL_OPERAND_SIZES_ELEM,
|
||||
],
|
||||
ADD_OP_AS_COL_ELEM,
|
||||
[
|
||||
ADD_OP_HIDDEN_ELEM,
|
||||
PFAIL_2M128_HIDDEN_ELEM,
|
||||
KSPBS_HIDDEN_ELEM,
|
||||
M2C2_PFAIL_128_ELEM,
|
||||
*ALL_OPERAND_SIZES_ELEM,
|
||||
],
|
||||
[
|
||||
ADD_OP_HIDDEN_ELEM,
|
||||
PFAIL_2M128_HIDDEN_ELEM,
|
||||
KS32PBS_HIDDEN_ELEM,
|
||||
M2C2_PFAIL_128_KS32_ELEM,
|
||||
*ALL_OPERAND_SIZES_ELEM,
|
||||
],
|
||||
MUL_OP_AS_COL_ELEM,
|
||||
[
|
||||
MUL_OP_HIDDEN_ELEM,
|
||||
PFAIL_2M128_HIDDEN_ELEM,
|
||||
KSPBS_HIDDEN_ELEM,
|
||||
M2C2_PFAIL_128_ELEM,
|
||||
*ALL_OPERAND_SIZES_ELEM,
|
||||
],
|
||||
[
|
||||
MUL_OP_HIDDEN_ELEM,
|
||||
PFAIL_2M128_HIDDEN_ELEM,
|
||||
KS32PBS_HIDDEN_ELEM,
|
||||
M2C2_PFAIL_128_KS32_ELEM,
|
||||
*ALL_OPERAND_SIZES_ELEM,
|
||||
],
|
||||
]
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# # No LaTex set for this table.
|
||||
# table_plaintext_ciphertext_ops_pfail64_ks32 = LatexTable()
|
||||
#
|
||||
# # No LaTex set for this table.
|
||||
# table_plaintext_ciphertext_ops_pfail128_ks32 = LatexTable()
|
||||
#
|
||||
# # No LaTex set for this table.
|
||||
# table_ciphertext_ciphertext_ops_pfail64_ks32 = LatexTable()
|
||||
#
|
||||
# # No LaTex set for this table.
|
||||
# table_ciphertext_ciphertext_ops_pfail128_ks32 = LatexTable()
|
||||
#
|
||||
# table_compression_benchmarks = LatexTable()
|
||||
269
ci/data_extractor/src/whitepaper/whitepaper.py
Normal file
269
ci/data_extractor/src/whitepaper/whitepaper.py
Normal file
@@ -0,0 +1,269 @@
|
||||
import copy
|
||||
import itertools
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
import utils
|
||||
|
||||
import config, connector
|
||||
from benchmark_specs import (
|
||||
AtomicPattern,
|
||||
ErrorFailureProbability,
|
||||
GroupingFactor,
|
||||
Precision,
|
||||
)
|
||||
from formatters.common import (
|
||||
LatexTable,
|
||||
)
|
||||
from formatters.core import CoreFormatter
|
||||
from formatters.integer import IntegerFormatter
|
||||
from .tables import (
|
||||
TABLE_BENCH_MULTIBIT_BY_PRECISION,
|
||||
TABLE_COMPARISON_OPERATIONS_BOOTSTRAPPING128KS32,
|
||||
TABLE_COMPARISON_OPERATIONS_PRECISION_PFAIL64,
|
||||
TABLE_COMPARISON_OPERATIONS_PRECISION_PFAIL128,
|
||||
TABLE_PBS_BENCH,
|
||||
)
|
||||
|
||||
|
||||
class Default(dict):
|
||||
def __missing__(self, key):
|
||||
return f"{{{key}}}"
|
||||
|
||||
|
||||
class ParametersFilterCase:
|
||||
def __init__(
|
||||
self,
|
||||
param_name_pattern: str,
|
||||
pfails: list[ErrorFailureProbability] = None,
|
||||
grouping_factors: list[GroupingFactor] = None,
|
||||
precisions: list[Precision] = None,
|
||||
atomic_patterns: list[AtomicPattern] = None,
|
||||
associated_tables: list[LatexTable] = None,
|
||||
):
|
||||
self.param_name_pattern = param_name_pattern
|
||||
self.pfails = pfails or []
|
||||
self.grouping_factors = grouping_factors or []
|
||||
self.precisions = precisions or []
|
||||
self.atomic_patterns = atomic_patterns or []
|
||||
|
||||
self.associated_tables = associated_tables or []
|
||||
|
||||
def get_parameter_variants(self):
|
||||
after_pfails = []
|
||||
for pfail in self.pfails:
|
||||
after_pfails.append(
|
||||
self.param_name_pattern.format_map(Default(pfail=pfail.to_str()))
|
||||
)
|
||||
|
||||
after_grouping_factors = []
|
||||
if after_pfails:
|
||||
for name in after_pfails:
|
||||
for gf in self.grouping_factors:
|
||||
after_grouping_factors.append(name.format_map(Default(gf=gf)))
|
||||
else:
|
||||
for gf in self.grouping_factors:
|
||||
after_grouping_factors.append(
|
||||
self.grouping_factors.format_map(Default(gf=gf))
|
||||
)
|
||||
|
||||
last_populated = after_grouping_factors or after_pfails
|
||||
after_msg_carry_sizes = []
|
||||
if last_populated:
|
||||
for name in last_populated:
|
||||
for p in self.precisions:
|
||||
after_msg_carry_sizes.append(
|
||||
name.format_map(Default(msg=p.message(), carry=p.carry()))
|
||||
)
|
||||
else:
|
||||
for p in self.precisions:
|
||||
after_msg_carry_sizes.append(
|
||||
self.param_name_pattern.format_map(
|
||||
Default(msg=p.message(), carry=p.carry())
|
||||
)
|
||||
)
|
||||
|
||||
last_populated = last_populated or after_msg_carry_sizes
|
||||
after_atomic_patterns = []
|
||||
if last_populated:
|
||||
for name in last_populated:
|
||||
for a in self.atomic_patterns:
|
||||
after_atomic_patterns.append(
|
||||
name.format_map(Default(atomic_pattern=a))
|
||||
)
|
||||
else:
|
||||
for a in self.atomic_patterns:
|
||||
after_atomic_patterns.append(
|
||||
self.param_name_pattern.format_map(Default(atomic_pattern=a))
|
||||
)
|
||||
|
||||
return last_populated or after_atomic_patterns or [self.param_name_pattern]
|
||||
|
||||
|
||||
CORE_CRYPTO_PARAM_CASES = [
|
||||
ParametersFilterCase(
|
||||
"%PARAM_MESSAGE_{msg}_CARRY_{carry}_KS_PBS_GAUSSIAN_{pfail}",
|
||||
pfails=[
|
||||
ErrorFailureProbability.TWO_MINUS_64,
|
||||
ErrorFailureProbability.TWO_MINUS_128,
|
||||
],
|
||||
precisions=[Precision.M1C1, Precision.M2C2, Precision.M3C3, Precision.M4C4],
|
||||
associated_tables=[TABLE_PBS_BENCH],
|
||||
),
|
||||
ParametersFilterCase(
|
||||
"%PARAM_MULTI_BIT_GROUP_{gf}_MESSAGE_{msg}_CARRY_{carry}_KS_PBS_GAUSSIAN_{pfail}",
|
||||
pfails=[
|
||||
ErrorFailureProbability.TWO_MINUS_64,
|
||||
ErrorFailureProbability.TWO_MINUS_128,
|
||||
],
|
||||
grouping_factors=[
|
||||
GroupingFactor.Two,
|
||||
GroupingFactor.Three,
|
||||
GroupingFactor.Four,
|
||||
],
|
||||
precisions=[Precision.M1C1, Precision.M2C2, Precision.M3C3, Precision.M4C4],
|
||||
associated_tables=[TABLE_BENCH_MULTIBIT_BY_PRECISION],
|
||||
),
|
||||
]
|
||||
|
||||
INTEGER_PARAM_CASES = [
|
||||
# ParametersFilterCase( # TODO Table 5, 6
|
||||
# "%PARAM_MESSAGE_{msg}_CARRY_{carry}_KS_PBS_GAUSSIAN_{pfail}", # 1_1, 2_2, 4_4 (pfail: 2m64, 2m128)
|
||||
# pfails=[
|
||||
# ErrorFailureProbability.TWO_MINUS_64,
|
||||
# ErrorFailureProbability.TWO_MINUS_128,
|
||||
# ],
|
||||
# precisions=[Precision.M1C1, Precision.M2C2, Precision.M4C4],
|
||||
# associated_tables=[
|
||||
# TABLE_COMPARISON_OPERATIONS_PRECISION_PFAIL64,
|
||||
# TABLE_COMPARISON_OPERATIONS_PRECISION_PFAIL128,
|
||||
# ],
|
||||
# ),
|
||||
ParametersFilterCase( # TODO 8
|
||||
"%PARAM_MESSAGE_2_CARRY_2_{atomic_pattern}_GAUSSIAN_2M128",
|
||||
atomic_patterns=[
|
||||
AtomicPattern.KSPBS,
|
||||
AtomicPattern.KS32PBS,
|
||||
],
|
||||
associated_tables=[TABLE_COMPARISON_OPERATIONS_BOOTSTRAPPING128KS32],
|
||||
),
|
||||
ParametersFilterCase( # TODO Table 9, 10, 11, 12, NEED TO DEAL WITH OPERANDS (ct and plaintext)
|
||||
"%PARAM_MESSAGE_2_CARRY_2_KS32_PBS_GAUSSIAN_{pfail}",
|
||||
pfails=[
|
||||
ErrorFailureProbability.TWO_MINUS_64,
|
||||
ErrorFailureProbability.TWO_MINUS_128,
|
||||
],
|
||||
associated_tables=BLAH,
|
||||
),
|
||||
# ParametersFilterCase( # TODO Table 7
|
||||
# "%PARAM_MULTI_BIT_GROUP_{gf}_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_{pfail}",
|
||||
# pfails=[
|
||||
# ErrorFailureProbability.TWO_MINUS_64,
|
||||
# ErrorFailureProbability.TWO_MINUS_128,
|
||||
# ],
|
||||
# grouping_factors=[
|
||||
# GroupingFactor.Two,
|
||||
# GroupingFactor.Three,
|
||||
# GroupingFactor.Four,
|
||||
# ],
|
||||
# ),
|
||||
# ParametersFilterCase( # TODO Table 13
|
||||
# "%COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_{pfail}",
|
||||
# pfails=[
|
||||
# ErrorFailureProbability.TWO_MINUS_64,
|
||||
# ErrorFailureProbability.TWO_MINUS_128,
|
||||
# ],
|
||||
# ),
|
||||
]
|
||||
|
||||
|
||||
def _generate_latex_tables(
|
||||
conn: connector.PostgreConnector,
|
||||
user_config: config.UserConfig,
|
||||
result_dir: pathlib.Path,
|
||||
):
|
||||
conversion_func = utils.convert_latency_value_to_readable_text
|
||||
|
||||
case_config = copy.deepcopy(user_config)
|
||||
case_config.backend = config.Backend.CPU
|
||||
case_config.pbs_kind = config.PBSKind.Any
|
||||
|
||||
# for case in CORE_CRYPTO_PARAM_CASES:
|
||||
# case_config.layer = config.Layer.CoreCrypto
|
||||
# param_patterns = case.get_parameter_variants()
|
||||
# res = conn.fetch_benchmark_data(case_config, param_name_patterns=param_patterns)
|
||||
#
|
||||
# generic_formatter = CoreFormatter(
|
||||
# case_config.layer,
|
||||
# case_config.backend,
|
||||
# case_config.pbs_kind,
|
||||
# case_config.grouping_factor,
|
||||
# )
|
||||
# formatted_results = generic_formatter.format_data_with_available_sizes(
|
||||
# res,
|
||||
# conversion_func,
|
||||
# )
|
||||
#
|
||||
# for k, v in formatted_results.items(): # DEBUG
|
||||
# print(k)
|
||||
# for sub_k, sub_v in v.items():
|
||||
# print(f"\t{sub_k}: {sub_v}")
|
||||
# print("")
|
||||
#
|
||||
# print("--------------------------------------------------")
|
||||
# print("--------------------------------------------------")
|
||||
# print("--------------------------------------------------")
|
||||
#
|
||||
# for table in case.associated_tables:
|
||||
# formatted_table = table.format_table(
|
||||
# formatted_results
|
||||
# )
|
||||
# print(formatted_table) # DEBUG
|
||||
|
||||
for case in INTEGER_PARAM_CASES:
|
||||
case_config.layer = config.Layer.Integer
|
||||
param_patterns = case.get_parameter_variants()
|
||||
res = conn.fetch_benchmark_data(case_config, param_name_patterns=param_patterns)
|
||||
|
||||
generic_formatter = IntegerFormatter(
|
||||
case_config.layer,
|
||||
case_config.backend,
|
||||
case_config.pbs_kind,
|
||||
case_config.grouping_factor,
|
||||
)
|
||||
formatted_results = generic_formatter.format_data_with_available_sizes(
|
||||
res,
|
||||
conversion_func,
|
||||
)
|
||||
|
||||
# FIXME il faut qu'on puisse avoir accès aux définitions de paramètres après formattage afin de pouvoir filtrer
|
||||
# ensuite sans quoi on ne peut pas faire de parameters matching
|
||||
for k, v in formatted_results.items(): # DEBUG
|
||||
print(k)
|
||||
for sub_k, sub_v in v.items():
|
||||
print(f"\t{sub_k}: {sub_v}")
|
||||
print("")
|
||||
|
||||
print("--------------------------------------------------")
|
||||
print("--------------------------------------------------")
|
||||
print("--------------------------------------------------")
|
||||
|
||||
for table in case.associated_tables:
|
||||
formatted_table = table.format_table(formatted_results)
|
||||
# TODO écrire chaque table dans un fichier en récupérant le __name__ de la table, le lower() et strip() le préfixe "table_"
|
||||
print(formatted_table) # DEBUG
|
||||
|
||||
# TODO prendre la valeur minimum dans un groupe d'opération (ex: min/max, gt/ge/lt/le, ...)
|
||||
|
||||
|
||||
def perform_latex_generation(
|
||||
conn: connector.PostgreConnector, user_config: config.UserConfig
|
||||
):
|
||||
dir_path = user_config.output_file
|
||||
try:
|
||||
dir_path.mkdir(parents=True, exist_ok=True)
|
||||
except Exception as e: ## TODO find the exact exception that can be raised here
|
||||
pass
|
||||
|
||||
# TODO passer un dossier en user_config et enregistrer les tables latex dans ce dossier un fichier = une table
|
||||
_generate_latex_tables(conn, user_config, dir_path)
|
||||
Reference in New Issue
Block a user