WIP: implement tables cases - half done

This commit is contained in:
David Testé
2025-12-19 09:41:17 +01:00
parent f15e96fde3
commit 186ace4fb4
10 changed files with 1267 additions and 216 deletions

View File

@@ -50,6 +50,14 @@ class Layer(enum.StrEnum):
raise NotImplementedError(f"layer '{layer_name}' not supported")
class OperandSize(int):
"""
Syntactic sugar for operand sizes handled as integer.
"""
pass
class RustType(enum.Enum):
"""
Represents different integer Rust types used in tfhe-rs.
@@ -189,6 +197,26 @@ class OperandType(enum.StrEnum):
PlainText = "PlainText"
class AtomicPattern(enum.StrEnum):
KSPBS = "KS_PBS"
PBSKS = "PBS_KS"
KS32PBS = "KS32_PBS"
@staticmethod
def from_str(pattern_name):
match pattern_name.lower():
case "ks_pbs":
return AtomicPattern.KSPBS
case "pbs_ks":
return AtomicPattern.PBSKS
case "ks32_pbs":
return AtomicPattern.KS32PBS
case _:
raise NotImplementedError(
f"atomic pattern '{pattern_name}' not supported"
)
class PBSKind(enum.StrEnum):
"""
Represents the kind of parameter set used for Programmable Bootstrapping operation.
@@ -279,6 +307,71 @@ class ErrorFailureProbability(enum.IntEnum):
return self.to_str()
class GroupingFactor(enum.IntEnum):
Two = 2
Three = 3
Four = 4
@staticmethod
def from_str(gf_value):
try:
int_value = int(gf_value)
except ValueError:
raise ValueError(f"grouping factor '{gf_value}' is not an integer")
match int_value:
case 2:
return GroupingFactor.Two
case 3:
return GroupingFactor.Three
case 4:
return GroupingFactor.Four
case _:
raise NotImplementedError(f"grouping factor '{gf_value}' not supported")
class Precision(enum.Enum):
M1C1 = (1, 1)
M2C2 = (2, 2)
M3C3 = (3, 3)
M4C4 = (4, 4)
@staticmethod
def from_param_name(name: str) -> enum.Enum:
parts = name.split("_")
message = None
carry = None
for i, part in enumerate(parts):
if part == "MESSAGE":
message = int(parts[i + 1])
elif part == "CARRY":
carry = int(parts[i + 1])
if message is None or carry is None:
raise ValueError(f"could not extract precision in '{name}'")
match (message, carry):
case (1, 1):
return Precision.M1C1
case (2, 2):
return Precision.M2C2
case (3, 3):
return Precision.M3C3
case (4, 4):
return Precision.M4C4
case _:
raise NotImplementedError(
f"precision message={message} carry={carry} not supported yet"
)
def message(self) -> int:
return self.value[0]
def carry(self) -> int:
return self.value[1]
class BenchType(enum.Enum):
Latency = 0
Throughput = 1
@@ -309,8 +402,7 @@ class ParamsDefinition:
"""
def __init__(self, param_name: str):
self.message_size = None
self.carry_size = None
self.precision = None
self.pbs_kind = None
self.grouping_factor = None
self.noise_distribution = None
@@ -323,8 +415,8 @@ class ParamsDefinition:
def __eq__(self, other):
return (
self.message_size == other.message_size
and self.carry_size == other.carry_size
self.precision.message() == other.precision.message()
and self.precision.carry() == other.precision.carry()
and self.pbs_kind == other.pbs_kind
and self.grouping_factor == other.grouping_factor
and self.noise_distribution == other.noise_distribution
@@ -337,16 +429,16 @@ class ParamsDefinition:
def __lt__(self, other):
return (
self.message_size < other.message_size
and self.carry_size < other.carry_size
self.precision.message() < other.precision.message()
and self.precision.carry() < other.precision.carry()
and self.p_fail < other.p_fail
)
def __hash__(self):
return hash(
(
self.message_size,
self.carry_size,
self.precision.message,
self.precision.carry,
self.pbs_kind,
self.grouping_factor,
self.noise_distribution,
@@ -357,7 +449,47 @@ class ParamsDefinition:
)
def __repr__(self):
return f"ParamsDefinition(message_size={self.message_size}, carry_size={self.carry_size}, pbs_kind={self.pbs_kind}, grouping_factor={self.grouping_factor}, noise_distribution={self.noise_distribution}, atomic_pattern={self.atomic_pattern}, p_fail={self.p_fail}, version={self.version}, details={self.details})"
return (
f"ParamsDefinition("
f"message_size={self.precision.message()}, "
f"carry_size={self.precision.carry()}, "
# f"pbs_kind={self.pbs_kind}, "
f"grouping_factor={self.grouping_factor}, "
# f"noise_distribution={self.noise_distribution}, "
f"atomic_pattern={self.atomic_pattern}, "
f"p_fail={self.p_fail}, "
# f"version={self.version}, "
# f"details={self.details})"
)
def components_match(self, *components):
for component in components:
match component:
case Precision():
if self.precision != component:
return False
case PBSKind():
if self.pbs_kind != component:
return False
case GroupingFactor():
if self.grouping_factor != component:
return False
case NoiseDistribution():
if self.noise_distribution != component:
return False
case AtomicPattern():
if self.atomic_pattern != component:
return False
case ErrorFailureProbability():
if self.p_fail != component:
return False
case _:
raise ValueError(
f"unsupported component type for matching '{type(component)}'"
)
# Each component matches
return True
def _parse_param_name(self, param_name: str) -> None:
split_params = param_name.split("_")
@@ -376,6 +508,12 @@ class ParamsDefinition:
params_variation_parts.append(part)
try:
self.precision = Precision.from_param_name(param_name)
except ValueError or NotImplementedError:
# Might be a Boolean parameters set
raise ParametersFormatNotSupported(param_name)
try:
self.p_fail = ErrorFailureProbability.from_param_name(param_name)
pfail_index = split_params.index(self.p_fail.to_str())
@@ -395,23 +533,17 @@ class ParamsDefinition:
noise_distribution_index = None
try:
self.message_size = int(split_params[split_params.index("MESSAGE") + 1])
carry_size_index = split_params.index("CARRY") + 1
self.carry_size = int(split_params[carry_size_index])
self.atomic_pattern = "_".join(
split_params[carry_size_index + 1 : noise_distribution_index]
)
except ValueError:
# Might be a Boolean parameters set
raise ParametersFormatNotSupported(param_name)
try:
if noise_distribution_index:
self.atomic_pattern = "_".join(
split_params[carry_size_index + 1 : noise_distribution_index]
self.atomic_pattern = AtomicPattern.from_str(
"_".join(
split_params[carry_size_index + 1 : noise_distribution_index]
)
)
else:
self.atomic_pattern = "_".join(split_params[carry_size_index + 1 :])
self.atomic_pattern = AtomicPattern.from_str(
"_".join(split_params[carry_size_index + 1 :])
)
except ValueError:
# Might be a Boolean parameters set
raise ParametersFormatNotSupported(param_name)
@@ -424,7 +556,9 @@ class ParamsDefinition:
try:
# This is a multi-bit parameters set
self.grouping_factor = int(split_params[split_params.index("GROUP") + 1])
self.grouping_factor = GroupingFactor.from_str(
split_params[split_params.index("GROUP") + 1]
)
self.pbs_kind = PBSKind.MultiBit
except ValueError:
# This is a classical parameters set
@@ -448,7 +582,7 @@ class BenchDetails:
:type bit_size: int
"""
def __init__(self, layer: Layer, bench_full_name: str, bit_size: int):
def __init__(self, layer: Layer, bench_full_name: str, bit_size: OperandSize):
self.layer = layer
self.operation_name = None

View File

@@ -9,6 +9,7 @@ from benchmark_specs import (
BenchDetails,
BenchType,
Layer,
OperandSize,
OperandType,
PBSKind,
)
@@ -302,7 +303,7 @@ class PostgreConnector:
raise NoDataFound(msg)
for line in lines:
bit_width = line[1]
bit_width = OperandSize(line[1])
bench_details = BenchDetails(layer, line[0], bit_width)
value = line[-1] if last_value_only else line[-3]

View File

@@ -25,9 +25,9 @@ import formatters.core
import formatters.hlapi
import formatters.integer
import regression
import whitepaper
from benchmark_specs import BenchType, Layer, OperandType, RustType
from formatters.common import BenchArray, CSVFormatter, MarkdownFormatter, SVGFormatter
from whitepaper import whitepaper
import utils
@@ -442,7 +442,7 @@ if __name__ == "__main__":
if args.generate_whitepaper_latex:
try:
whitepaper.perform_latex_generation(conn, user_config)
except RuntimeError as err:
except Exception as err:
print(f"Failed to perform whitepaper LaTex tables generation: {err}")
sys.exit(2)
else:

View File

@@ -1,7 +1,9 @@
import collections
import enum
import pathlib
import xml.dom.minidom
from collections.abc import Callable
from typing import Any
import svg
from benchmark_specs import (
@@ -12,6 +14,7 @@ from benchmark_specs import (
ErrorFailureProbability,
Layer,
NoiseDistribution,
OperandSize,
OperandType,
PBSKind,
RustType,
@@ -189,6 +192,16 @@ class GenericFormatter:
# Must be implemented by subclasses
raise NotImplementedError
def format_data_with_available_sizes(
self, data: dict[BenchDetails : list[int]], conversion_func
):
return self._format_data_with_available_sizes(data, conversion_func)
@staticmethod
def _format_data_with_available_sizes(*args, **kwargs):
# Must be implemented by subclasses
raise NotImplementedError
def generate_array(
self,
data,
@@ -659,11 +672,170 @@ class SVGFormatter(GenericFormatter):
return self.generate_svg_table(array)
class LatexColumn:
pass
# TODO faire une sous classe de certaines enum avec juste une méthode to_latex_str() pour avoir la valeur qui correspond a celle affichée dans le whitepaper
class LatexRow:
pass
class ElementType(enum.Enum):
Operation = 0
ParamComponent = 1
SizeComponent = 2
class LatexSeparator(enum.StrEnum):
BottomRule = r"\bottomrule"
MidRule = r"\midrule"
HLine = r"\hline"
class LatexElement:
def __init__(
self,
elem: Any,
elem_type: ElementType,
latex_str: str,
display_element: bool = True,
):
self.elem = elem
self.type = elem_type
self.latex_str = latex_str
self.display_element = display_element
def format(self):
return self.latex_str
class LatexRowElement(LatexElement):
def __init__(
self, *args, display_as_column: bool = False, column_span: int = None, **kwargs
):
super().__init__(*args, **kwargs)
self.display_as_column = display_as_column
self.column_span = column_span or 1
def format(self):
if self.display_as_column:
return "\n".join(
[
f"{LatexSeparator.MidRule}",
f"\\multicolumn{{{self.column_span}}}{{c}}{{\\textbf{{{self.latex_str}}}}} \\\\",
f"{LatexSeparator.MidRule}",
]
)
else:
return self.latex_str
class LatexColumnElement(LatexElement):
def __init__(self, *args, sub_cols=None, **kwargs):
super().__init__(*args, **kwargs)
self.sub_columns: list[LatexColumnElement] = sub_cols or []
class LatexArraySection:
def __init__(self, rows: list):
self.rows = rows
# TODO intégrer ici la notion de mise en gras du/des résultats les plus petits (enum ?)
# Ca peut être le minimum sur une ligne (ou un partie de ligne, cf multi-bit core)
# ou bien le minimum sur une section entière (cf classic core)
def format(self, bench_data):
lines = []
# IMPORTANT: data must be sorted with row discriminant first and then column discriminant that would give access to the value
for row in self.rows:
try:
lines.append(self._format_row(row, bench_data))
except TypeError:
# Row is a single LatexRowElement
lines.append(row.format())
print("Lines:", lines) # DEBUG
return lines
@staticmethod
def _format_row(row, bench_data):
operation = None
row_filters = []
row_values = []
for elem in row:
if isinstance(elem, LatexRowElement):
match elem.type:
case ElementType.Operation:
operation = elem.elem
case ElementType.ParamComponent:
row_filters.append(elem.elem)
if elem.display_element:
row_values.append(elem.format())
elif isinstance(elem, LatexColumnElement):
bench_values = bench_data.get(operation)
if elem.sub_columns:
for sub_col in elem.sub_columns:
for param_def, value in bench_values.items():
if isinstance(sub_col.elem, OperandSize):
bit_size = sub_col.elem
if param_def.components_match(*row_filters, elem.elem):
try:
row_values.append(value[bit_size])
except KeyError:
# TODO maybe returning a "N/A" value would be better than raising an error
raise KeyError(
f"bit size '{bit_size}' not found in bench data (params: {param_def}, values: {value})"
)
else:
if param_def.components_match(
*row_filters, elem.elem, sub_col.elem
):
# Only one value is stored in the operand size dict.
_, v = value.popitem()
row_values.append(v)
else:
for param_def, value in bench_values.items():
if isinstance(
value, dict
): # TODO If(condition) et else(content) à supprimer
# Handling data from the integer tfhe-rs layer.
bit_size = elem.elem
if param_def.components_match(*row_filters):
try:
row_values.append(value[bit_size])
except KeyError:
# TODO maybe returning a "N/A" value would be better than raising an error
raise KeyError(
f"bit size '{bit_size}' not found in bench data (params: {param_def}, values: {value})"
)
else:
# Handling data from the core-crypto tfhe-rs layer.
if param_def.components_match(*row_filters, elem.elem):
row_values.append(value)
print("Row values:", row_values)
return " & ".join(row_values) + r" \\"
class LatexTable:
def __init__(
self,
array_section: list[LatexArraySection],
):
self.array_sections = array_section
@staticmethod
def get_separator(sep: LatexSeparator):
return sep.value
def format_table(self, bench_data):
parts = []
for section in self.array_sections:
parts.extend(section.format(bench_data))
parts.append(self.get_separator(LatexSeparator.HLine))
# Replace the last section separator with a bottom rule
parts[-1] = self.get_separator(LatexSeparator.BottomRule)
return "\n".join(parts)
class LatexFormatter(GenericFormatter):

View File

@@ -88,6 +88,22 @@ class CoreFormatter(GenericFormatter):
return formatted
@staticmethod
def _format_data_with_available_sizes(
data: dict[BenchDetails : list[int]], conversion_func
):
formatted = collections.defaultdict(lambda: collections.defaultdict(lambda: {}))
for details, timings in data.items():
reduced_params = details.get_params_definition()
test_name = details.operation_name
bit_width = details.bit_size
value = conversion_func(timings[-1])
formatted[test_name][reduced_params][bit_width] = value
return formatted
def _generate_arrays(
self,
data,

View File

@@ -56,6 +56,22 @@ class IntegerFormatter(GenericFormatter):
return formatted
@staticmethod
def _format_data_with_available_sizes(
data: dict[BenchDetails : list[int]], conversion_func
):
formatted = collections.defaultdict(lambda: collections.defaultdict(lambda: {}))
for details, timings in data.items():
reduced_params = details.get_params_definition()
test_name = "_".join((details.sign_flavor.value, details.operation_name))
bit_width = details.bit_size
value = conversion_func(timings[-1])
formatted[test_name][reduced_params][bit_width] = value
return formatted
def _generate_arrays(
self,
data,

View File

@@ -1,185 +0,0 @@
import copy
import itertools
import pathlib
import sys
import config
import connector
from benchmark_specs import ErrorFailureProbability
from formatters.core import CoreFormatter
import utils
class Default(dict):
def __missing__(self, key):
return f"{{{key}}}"
class ParametersFilterCase:
def __init__(
self,
param_name_pattern: str,
pfails: list[ErrorFailureProbability] = None,
grouping_factors: list[int] = None,
message_carry_sizes: list[int] = None,
):
self.param_name_pattern = param_name_pattern
self.pfails = pfails or []
self.grouping_factors = grouping_factors or []
self.message_carry_sizes = message_carry_sizes or []
def get_parameter_variants(self):
after_pfails = []
for pfail in self.pfails:
after_pfails.append(
self.param_name_pattern.format_map(Default(pfail=pfail.to_str()))
)
after_grouping_factors = []
if after_pfails:
for name in after_pfails:
for gf in self.grouping_factors:
after_grouping_factors.append(name.format_map(Default(gf=gf)))
else:
for gf in self.grouping_factors:
after_grouping_factors.append(
self.grouping_factors.format_map(Default(gf=gf))
)
after_msg_carry_sizes = []
if after_grouping_factors:
for name in after_grouping_factors:
for size in self.message_carry_sizes:
after_msg_carry_sizes.append(
name.format_map(Default(msg=size, carry=size))
)
else:
for size in self.message_carry_sizes:
after_msg_carry_sizes.append(
self.param_name_pattern.format_map(Default(msg=size, carry=size))
)
return (
after_msg_carry_sizes
or after_grouping_factors
or after_pfails
or [self.param_name_pattern]
)
CORE_CRYPTO_PARAM_CASES = [
ParametersFilterCase(
"%PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_{pfail}",
pfails=[
ErrorFailureProbability.TWO_MINUS_64,
ErrorFailureProbability.TWO_MINUS_128,
],
),
ParametersFilterCase(
"%PARAM_MULTI_BIT_GROUP_{gf}_MESSAGE_{msg}_CARRY_{carry}_KS_PBS_GAUSSIAN_{pfail}",
pfails=[
ErrorFailureProbability.TWO_MINUS_64,
ErrorFailureProbability.TWO_MINUS_128,
],
grouping_factors=[2, 3, 4],
message_carry_sizes=[1, 2, 3, 4],
),
]
INTEGER_PARAM_CASES = [
ParametersFilterCase(
"%PARAM_MESSAGE_{msg}_CARRY_{carry}_KS_PBS_GAUSSIAN_{pfail}", # 1_1, 2_2, 4_4 (pfail: 2m64, 2m128)
pfails=[
ErrorFailureProbability.TWO_MINUS_64,
ErrorFailureProbability.TWO_MINUS_128,
],
message_carry_sizes=[1, 2, 4],
),
ParametersFilterCase(
"%PARAM_MESSAGE_2_CARRY_2_KS32_PBS_GAUSSIAN_{pfail}",
pfails=[
ErrorFailureProbability.TWO_MINUS_64,
ErrorFailureProbability.TWO_MINUS_128,
],
),
ParametersFilterCase(
"%PARAM_MULTI_BIT_GROUP_{gf}_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_{pfail}",
pfails=[
ErrorFailureProbability.TWO_MINUS_64,
ErrorFailureProbability.TWO_MINUS_128,
],
grouping_factors=[2, 3, 4],
),
ParametersFilterCase(
"%COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_{pfail}",
pfails=[
ErrorFailureProbability.TWO_MINUS_64,
ErrorFailureProbability.TWO_MINUS_128,
],
),
]
INTEGER_SPECIAL_CASE_OPERATIONS_FILTER = [
"add_parallelized",
"mul_parallelized",
"bitand_parallelized",
]
def _generate_latex_tables(
conn: connector.PostgreConnector,
user_config: config.UserConfig,
result_dir: pathlib.Path,
):
conversion_func = utils.convert_latency_value_to_readable_text
case_config = copy.deepcopy(user_config)
case_config.backend = config.Backend.CPU
case_config.pbs_kind = config.PBSKind.Any
for case in CORE_CRYPTO_PARAM_CASES:
case_config.layer = config.Layer.CoreCrypto
param_patterns = case.get_parameter_variants()
res = conn.fetch_benchmark_data(case_config, param_name_patterns=param_patterns)
generic_formatter = CoreFormatter(
case_config.layer,
case_config.backend,
case_config.pbs_kind,
case_config.grouping_factor,
)
formatted_results = generic_formatter.format_data(
res,
conversion_func,
)
# TODO créer le tbaleau qui va bien en fonction du cas
for r in formatted_results.items(): # DEBUG
print(r)
print("--------------------------------------------------")
print("--------------------------------------------------")
print("--------------------------------------------------")
# TODO prendre la valeur minimum dans un groupe d'opération (ex: min/max, gt/ge/lt/le, ...)
def perform_latex_generation(
conn: connector.PostgreConnector, user_config: config.UserConfig
):
dir_path = user_config.output_file
try:
dir_path.mkdir(parents=True, exist_ok=True)
except Exception as e: ## TODO find the exact exception that can be raised here
pass
# TODO passer un dossier en user_config et enregistrer les tables latex dans ce dossier un fichier = une table
_generate_latex_tables(conn, user_config, dir_path)
# Table 2 is core crypto KS/KS-PBS/PBS for 2_2_128 and 2_2_64
# Table 3 is core crypto KS/KS-PBS/PBS in the multibit case for V1_4_PARAM_MULTI_BIT_GROUP_X_MESSAGE_Y_CARRY_Y_KS_PBS_GAUSSIAN_2MZ for X in [2,3,4], Y in [1,2,3,4] and Z in [64, 128]
# Table 5 is special case with 1_1_64, 2_2_64, 4_4_64
# Table 6 is special case with 1_1_128, 2_2_128, 4_4_128
# Table 7 is the multibit special case with V1_4_PARAM_MULTI_BIT_GROUP_X_MESSAGE_Y_CARRY_Y_KS_PBS_GAUSSIAN_2MZ for X in [2,3,4], Y = 2, Z = 64.
# Table 8 is special case for 2_2_128 and 2_2_128 KS_32 <--- changed
# Table 9, 10, 11, 12 are the integer ops for 2_2_64_KS32 and 2_2_128_KS32
# Table 13 is compression (which I believe was in special case?) for the 64/128 compression parameters

View File

@@ -0,0 +1 @@
from . import whitepaper

View File

@@ -0,0 +1,627 @@
"""
All tables are named after their label in the whitepaper.
"""
from benchmark_specs import (
AtomicPattern,
ErrorFailureProbability,
GroupingFactor,
OperandSize,
Precision,
)
from formatters.common import (
ElementType,
LatexArraySection,
LatexColumnElement,
LatexRowElement,
LatexTable,
)
# Table 2 is core crypto KS/KS-PBS/PBS for 2_2_128 and 2_2_64
# Table 3 is core crypto KS/KS-PBS/PBS in the multibit case for V1_4_PARAM_MULTI_BIT_GROUP_X_MESSAGE_Y_CARRY_Y_KS_PBS_GAUSSIAN_2MZ for X in [2,3,4], Y in [1,2,3,4] and Z in [64, 128]
# Table 5 is special case with 1_1_64, 2_2_64, 4_4_64
# Table 6 is special case with 1_1_128, 2_2_128, 4_4_128
# Table 7 is the multibit special case with V1_4_PARAM_MULTI_BIT_GROUP_X_MESSAGE_Y_CARRY_Y_KS_PBS_GAUSSIAN_2MZ for X in [2,3,4], Y = 2, Z = 64.
# Table 8 is special case for 2_2_128 and 2_2_128 KS_32 <--- changed
# Table 9, 10, 11, 12 are the integer ops for 2_2_64_KS32 and 2_2_128_KS32
# Table 13 is compression (which I believe was in special case?) for the 64/128 compression parameters
# ------------------
# LaTex row elements
# ------------------
def _get_operation_elem(
operation_name: str,
latex_str: str,
display_element: bool = True,
display_as_column: bool = False,
column_span: int = 1,
) -> LatexRowElement:
return LatexRowElement(
operation_name,
ElementType.Operation,
latex_str,
display_element=display_element,
display_as_column=display_as_column,
column_span=column_span,
)
PFAIL_2M64_ELEM = LatexRowElement(
ErrorFailureProbability.TWO_MINUS_64, ElementType.ParamComponent, r"\(2^{-64}\)"
)
PFAIL_2M128_ELEM = LatexRowElement(
ErrorFailureProbability.TWO_MINUS_128, ElementType.ParamComponent, r"\(2^{-128}\)"
)
PFAIL_2M64_HIDDEN_ELEM = LatexRowElement(
ErrorFailureProbability.TWO_MINUS_64,
ElementType.ParamComponent,
r"\(2^{-64}\)",
display_element=False,
)
PFAIL_2M128_HIDDEN_ELEM = LatexRowElement(
ErrorFailureProbability.TWO_MINUS_128,
ElementType.ParamComponent,
r"\(2^{-128}\)",
display_element=False,
)
M1C1_PFAIL_64_ELEM = LatexRowElement(
Precision.M1C1,
ElementType.ParamComponent,
r"${\tt 1\_1\_64}$",
)
M2C2_PFAIL_64_ELEM = LatexRowElement(
Precision.M2C2,
ElementType.ParamComponent,
r"${\tt 2\_2\_64}$",
)
M4C4_PFAIL_64_ELEM = LatexRowElement(
Precision.M4C4,
ElementType.ParamComponent,
r"${\tt 4\_4\_64}$",
)
M1C1_PFAIL_128_ELEM = LatexRowElement(
Precision.M1C1,
ElementType.ParamComponent,
r"${\tt 1\_1\_128}$",
)
M2C2_PFAIL_128_ELEM = LatexRowElement(
Precision.M2C2,
ElementType.ParamComponent,
r"${\tt 2\_2\_128}$",
)
M2C2_PFAIL_128_KS32_ELEM = LatexRowElement(
Precision.M2C2,
ElementType.ParamComponent,
r"${\tt 2\_2\_128\_KS32}$",
)
M4C4_PFAIL_128_ELEM = LatexRowElement(
Precision.M4C4,
ElementType.ParamComponent,
r"${\tt 4\_4\_128}$",
)
KSPBS_HIDDEN_ELEM = LatexRowElement(
AtomicPattern.KSPBS,
ElementType.ParamComponent,
"",
display_element=False,
)
KS32PBS_HIDDEN_ELEM = LatexRowElement(
AtomicPattern.KS32PBS,
ElementType.ParamComponent,
"",
display_element=False,
)
KS_OP_ELEM = _get_operation_elem("keyswitch", r"\ks")
PBS_OP_ELEM = _get_operation_elem("pbs_mem_optimized", r"\pbs")
MB_PBS_OP_ELEM = _get_operation_elem("multi_bit_deterministic_pbs", r"\mbpbs")
KSPBS_OP_ELEM = _get_operation_elem("ks_pbs", r"\kspbs")
KS_MB_PBS_OP_ELEM = _get_operation_elem("multi_bit_deterministic_ks_pbs", r"\ksmbpbs")
ADD_OP_HIDDEN_ELEM = _get_operation_elem(
"unsigned_add_parallelized", "Addition", display_element=False
)
BITAND_OP_HIDDEN_ELEM = _get_operation_elem(
"unsigned_bitand_parallelized", "Bitwise AND", display_element=False
)
MUL_OP_HIDDEN_ELEM = _get_operation_elem(
"unsigned_mul_parallelized", "Multiplication", display_element=False
)
# ---------------------
# LaTex column elements
# ---------------------
def _get_precision_column_element(
precision: Precision, sub_cols: list[LatexColumnElement] = None
) -> LatexColumnElement:
return LatexColumnElement(
precision, ElementType.ParamComponent, "", sub_cols=sub_cols
)
def _get_grouping_factor_column_element(
grouping_factor: GroupingFactor,
) -> LatexColumnElement:
return LatexColumnElement(
grouping_factor, ElementType.ParamComponent, str(grouping_factor)
)
def _get_operand_size_column_element(
op_size: int,
) -> LatexColumnElement:
return LatexColumnElement(
OperandSize(op_size), ElementType.SizeComponent, str(op_size)
)
# Operand size is set to the value of the message size since it's stored as is in the database.
M1C1_COL_ELEM = _get_precision_column_element(
Precision.M1C1, sub_cols=[_get_operand_size_column_element(1)]
)
M2C2_COL_ELEM = _get_precision_column_element(
Precision.M2C2, sub_cols=[_get_operand_size_column_element(2)]
)
M3C3_COL_ELEM = _get_precision_column_element(
Precision.M3C3, sub_cols=[_get_operand_size_column_element(3)]
)
M4C4_COL_ELEM = _get_precision_column_element(
Precision.M4C4, sub_cols=[_get_operand_size_column_element(4)]
)
ALL_GROUPING_FACTORS_ELEM = [
_get_grouping_factor_column_element(GroupingFactor.Two),
_get_grouping_factor_column_element(GroupingFactor.Three),
_get_grouping_factor_column_element(GroupingFactor.Four),
]
M1C1_ALL_GF_COL_ELEM = _get_precision_column_element(
Precision.M1C1, sub_cols=ALL_GROUPING_FACTORS_ELEM
)
M2C2_ALL_GF_COL_ELEM = _get_precision_column_element(
Precision.M2C2, sub_cols=ALL_GROUPING_FACTORS_ELEM
)
M3C3_ALL_GF_COL_ELEM = _get_precision_column_element(
Precision.M3C3, sub_cols=ALL_GROUPING_FACTORS_ELEM
)
M4C4_ALL_GF_COL_ELEM = _get_precision_column_element(
Precision.M4C4, sub_cols=ALL_GROUPING_FACTORS_ELEM
)
ALL_OPERAND_SIZES_ELEM = [
_get_operand_size_column_element(4),
_get_operand_size_column_element(8),
_get_operand_size_column_element(16),
_get_operand_size_column_element(32),
_get_operand_size_column_element(64),
_get_operand_size_column_element(128),
_get_operand_size_column_element(256),
]
OP_AS_COL_SPAN = len(ALL_OPERAND_SIZES_ELEM) + 1
# Operations used in integer special case tables
ADD_OP_AS_COL_ELEM = _get_operation_elem(
"unsigned_add_parallelized",
"Addition",
display_as_column=True,
column_span=OP_AS_COL_SPAN,
)
BITAND_OP_AS_COL_ELEM = _get_operation_elem(
"unsigned_bitand_parallelized",
"Bitwise AND",
display_as_column=True,
column_span=OP_AS_COL_SPAN,
)
MUL_OP_AS_COL_ELEM = _get_operation_elem(
"unsigned_mul_parallelized",
"Multiplication",
display_as_column=True,
column_span=OP_AS_COL_SPAN,
)
# TODO On a besoin de garder quelque part la raw_value du bench (après conversion str) pour effectuer des comparaisons et trouver le minimum sur un groupe de résultats
# TODO Calculer les lignes "amortized" pour TABLE_2
# -----------------------
# LaTex table definitions
# -----------------------
TABLE_PBS_BENCH = LatexTable(
[
LatexArraySection(
[
[
PFAIL_2M64_ELEM,
KS_OP_ELEM,
M1C1_COL_ELEM,
M2C2_COL_ELEM,
M3C3_COL_ELEM,
M4C4_COL_ELEM,
],
[
PFAIL_2M64_ELEM,
PBS_OP_ELEM,
M1C1_COL_ELEM,
M2C2_COL_ELEM,
M3C3_COL_ELEM,
M4C4_COL_ELEM,
],
[
PFAIL_2M64_ELEM,
KSPBS_OP_ELEM,
M1C1_COL_ELEM,
M2C2_COL_ELEM,
M3C3_COL_ELEM,
M4C4_COL_ELEM,
],
# [PFAIL_2M64_ELEM, KSPBS_OP_ELEM_AMORTIZED, M1C1_COL_ELEM, M2C2_COL_ELEM, M3C3_COL_ELEM, M4C4_COL_ELEM], # TODO line de données à calculer depuis les résultats
]
),
LatexArraySection(
[
[
PFAIL_2M128_ELEM,
KS_OP_ELEM,
M1C1_COL_ELEM,
M2C2_COL_ELEM,
M3C3_COL_ELEM,
M4C4_COL_ELEM,
],
[
PFAIL_2M128_ELEM,
PBS_OP_ELEM,
M1C1_COL_ELEM,
M2C2_COL_ELEM,
M3C3_COL_ELEM,
M4C4_COL_ELEM,
],
[
PFAIL_2M128_ELEM,
KSPBS_OP_ELEM,
M1C1_COL_ELEM,
M2C2_COL_ELEM,
M3C3_COL_ELEM,
M4C4_COL_ELEM,
],
# [PFAIL_2M128_ELEM, KSPBS_OP_ELEM_AMORTIZED, M1C1_COL_ELEM, M2C2_COL_ELEM, Precision.M3C3, M4C4_COL_ELEM],
]
),
]
)
TABLE_BENCH_MULTIBIT_BY_PRECISION = LatexTable(
[
LatexArraySection(
[
[
PFAIL_2M64_ELEM,
KS_OP_ELEM,
M1C1_ALL_GF_COL_ELEM,
M2C2_ALL_GF_COL_ELEM,
M3C3_ALL_GF_COL_ELEM,
M4C4_ALL_GF_COL_ELEM,
],
[
PFAIL_2M64_ELEM,
MB_PBS_OP_ELEM,
M1C1_ALL_GF_COL_ELEM,
M2C2_ALL_GF_COL_ELEM,
M3C3_ALL_GF_COL_ELEM,
M4C4_ALL_GF_COL_ELEM,
],
[
PFAIL_2M64_ELEM,
KS_MB_PBS_OP_ELEM,
M1C1_ALL_GF_COL_ELEM,
M2C2_ALL_GF_COL_ELEM,
M3C3_ALL_GF_COL_ELEM,
M4C4_ALL_GF_COL_ELEM,
],
]
),
LatexArraySection(
[
[
PFAIL_2M128_ELEM,
KS_OP_ELEM,
M1C1_ALL_GF_COL_ELEM,
M2C2_ALL_GF_COL_ELEM,
M3C3_ALL_GF_COL_ELEM,
M4C4_ALL_GF_COL_ELEM,
],
[
PFAIL_2M128_ELEM,
MB_PBS_OP_ELEM,
M1C1_ALL_GF_COL_ELEM,
M2C2_ALL_GF_COL_ELEM,
M3C3_ALL_GF_COL_ELEM,
M4C4_ALL_GF_COL_ELEM,
],
[
PFAIL_2M128_ELEM,
KS_MB_PBS_OP_ELEM,
M1C1_ALL_GF_COL_ELEM,
M2C2_ALL_GF_COL_ELEM,
M3C3_ALL_GF_COL_ELEM,
M4C4_ALL_GF_COL_ELEM,
],
]
),
]
)
TABLE_COMPARISON_OPERATIONS_PRECISION_PFAIL64 = LatexTable(
[
LatexArraySection(
[
BITAND_OP_AS_COL_ELEM,
[
BITAND_OP_HIDDEN_ELEM,
PFAIL_2M64_HIDDEN_ELEM,
M1C1_PFAIL_64_ELEM,
*ALL_OPERAND_SIZES_ELEM,
],
[
BITAND_OP_HIDDEN_ELEM,
PFAIL_2M64_HIDDEN_ELEM,
M2C2_PFAIL_64_ELEM,
*ALL_OPERAND_SIZES_ELEM,
],
[
BITAND_OP_HIDDEN_ELEM,
PFAIL_2M64_HIDDEN_ELEM,
M4C4_PFAIL_64_ELEM,
*ALL_OPERAND_SIZES_ELEM,
],
ADD_OP_AS_COL_ELEM,
[
ADD_OP_HIDDEN_ELEM,
PFAIL_2M64_HIDDEN_ELEM,
M1C1_PFAIL_64_ELEM,
*ALL_OPERAND_SIZES_ELEM,
],
[
ADD_OP_HIDDEN_ELEM,
PFAIL_2M64_HIDDEN_ELEM,
M2C2_PFAIL_64_ELEM,
*ALL_OPERAND_SIZES_ELEM,
],
[
ADD_OP_HIDDEN_ELEM,
PFAIL_2M64_HIDDEN_ELEM,
M4C4_PFAIL_64_ELEM,
*ALL_OPERAND_SIZES_ELEM,
],
MUL_OP_AS_COL_ELEM,
[
MUL_OP_HIDDEN_ELEM,
PFAIL_2M64_HIDDEN_ELEM,
M1C1_PFAIL_64_ELEM,
*ALL_OPERAND_SIZES_ELEM,
],
[
MUL_OP_HIDDEN_ELEM,
PFAIL_2M64_HIDDEN_ELEM,
M2C2_PFAIL_64_ELEM,
*ALL_OPERAND_SIZES_ELEM,
],
[
MUL_OP_HIDDEN_ELEM,
PFAIL_2M64_HIDDEN_ELEM,
M4C4_PFAIL_64_ELEM,
*ALL_OPERAND_SIZES_ELEM,
],
]
),
]
)
TABLE_COMPARISON_OPERATIONS_PRECISION_PFAIL128 = LatexTable(
[
LatexArraySection(
[
BITAND_OP_AS_COL_ELEM,
[
BITAND_OP_HIDDEN_ELEM,
PFAIL_2M128_HIDDEN_ELEM,
M1C1_PFAIL_128_ELEM,
*ALL_OPERAND_SIZES_ELEM,
],
[
BITAND_OP_HIDDEN_ELEM,
PFAIL_2M64_HIDDEN_ELEM,
M2C2_PFAIL_64_ELEM,
*ALL_OPERAND_SIZES_ELEM,
],
[
BITAND_OP_HIDDEN_ELEM,
PFAIL_2M64_HIDDEN_ELEM,
M4C4_PFAIL_64_ELEM,
*ALL_OPERAND_SIZES_ELEM,
],
ADD_OP_AS_COL_ELEM,
[
ADD_OP_HIDDEN_ELEM,
PFAIL_2M128_HIDDEN_ELEM,
M1C1_PFAIL_128_ELEM,
*ALL_OPERAND_SIZES_ELEM,
],
[
ADD_OP_HIDDEN_ELEM,
PFAIL_2M64_HIDDEN_ELEM,
M2C2_PFAIL_128_ELEM,
*ALL_OPERAND_SIZES_ELEM,
],
[
ADD_OP_HIDDEN_ELEM,
PFAIL_2M64_HIDDEN_ELEM,
M4C4_PFAIL_128_ELEM,
*ALL_OPERAND_SIZES_ELEM,
],
MUL_OP_AS_COL_ELEM,
[
MUL_OP_HIDDEN_ELEM,
PFAIL_2M128_HIDDEN_ELEM,
M1C1_PFAIL_128_ELEM,
*ALL_OPERAND_SIZES_ELEM,
],
[
MUL_OP_HIDDEN_ELEM,
PFAIL_2M64_HIDDEN_ELEM,
M2C2_PFAIL_128_ELEM,
*ALL_OPERAND_SIZES_ELEM,
],
[
MUL_OP_HIDDEN_ELEM,
PFAIL_2M64_HIDDEN_ELEM,
M4C4_PFAIL_128_ELEM,
*ALL_OPERAND_SIZES_ELEM,
],
]
),
]
)
TABLE_COMPARISON_OPERATIONS_BOOTSTRAPPING = LatexTable(
[
LatexArraySection(
[
BITAND_OP_AS_COL_ELEM,
[
BITAND_OP_HIDDEN_ELEM,
PFAIL_2M128_HIDDEN_ELEM,
M1C1_PFAIL_128_ELEM,
*ALL_OPERAND_SIZES_ELEM,
],
[
BITAND_OP_HIDDEN_ELEM,
PFAIL_2M128_HIDDEN_ELEM,
M2C2_PFAIL_128_ELEM,
*ALL_OPERAND_SIZES_ELEM,
],
[
BITAND_OP_HIDDEN_ELEM,
PFAIL_2M128_HIDDEN_ELEM,
M4C4_PFAIL_128_ELEM,
*ALL_OPERAND_SIZES_ELEM,
],
ADD_OP_AS_COL_ELEM,
[
ADD_OP_HIDDEN_ELEM,
PFAIL_2M128_HIDDEN_ELEM,
M1C1_PFAIL_128_ELEM,
*ALL_OPERAND_SIZES_ELEM,
],
[
ADD_OP_HIDDEN_ELEM,
PFAIL_2M128_HIDDEN_ELEM,
M2C2_PFAIL_128_ELEM,
*ALL_OPERAND_SIZES_ELEM,
],
[
ADD_OP_HIDDEN_ELEM,
PFAIL_2M128_HIDDEN_ELEM,
M4C4_PFAIL_128_ELEM,
*ALL_OPERAND_SIZES_ELEM,
],
MUL_OP_AS_COL_ELEM,
[
MUL_OP_HIDDEN_ELEM,
PFAIL_2M128_HIDDEN_ELEM,
M1C1_PFAIL_128_ELEM,
*ALL_OPERAND_SIZES_ELEM,
],
[
MUL_OP_HIDDEN_ELEM,
PFAIL_2M128_HIDDEN_ELEM,
M2C2_PFAIL_128_ELEM,
*ALL_OPERAND_SIZES_ELEM,
],
[
MUL_OP_HIDDEN_ELEM,
PFAIL_2M128_HIDDEN_ELEM,
M4C4_PFAIL_128_ELEM,
*ALL_OPERAND_SIZES_ELEM,
],
]
),
]
)
TABLE_COMPARISON_OPERATIONS_BOOTSTRAPPING128KS32 = LatexTable(
[
LatexArraySection(
[
BITAND_OP_AS_COL_ELEM,
[
BITAND_OP_HIDDEN_ELEM,
PFAIL_2M128_HIDDEN_ELEM,
KSPBS_HIDDEN_ELEM,
M2C2_PFAIL_128_ELEM,
*ALL_OPERAND_SIZES_ELEM,
],
[
BITAND_OP_HIDDEN_ELEM,
PFAIL_2M128_HIDDEN_ELEM,
KS32PBS_HIDDEN_ELEM,
M2C2_PFAIL_128_KS32_ELEM,
*ALL_OPERAND_SIZES_ELEM,
],
ADD_OP_AS_COL_ELEM,
[
ADD_OP_HIDDEN_ELEM,
PFAIL_2M128_HIDDEN_ELEM,
KSPBS_HIDDEN_ELEM,
M2C2_PFAIL_128_ELEM,
*ALL_OPERAND_SIZES_ELEM,
],
[
ADD_OP_HIDDEN_ELEM,
PFAIL_2M128_HIDDEN_ELEM,
KS32PBS_HIDDEN_ELEM,
M2C2_PFAIL_128_KS32_ELEM,
*ALL_OPERAND_SIZES_ELEM,
],
MUL_OP_AS_COL_ELEM,
[
MUL_OP_HIDDEN_ELEM,
PFAIL_2M128_HIDDEN_ELEM,
KSPBS_HIDDEN_ELEM,
M2C2_PFAIL_128_ELEM,
*ALL_OPERAND_SIZES_ELEM,
],
[
MUL_OP_HIDDEN_ELEM,
PFAIL_2M128_HIDDEN_ELEM,
KS32PBS_HIDDEN_ELEM,
M2C2_PFAIL_128_KS32_ELEM,
*ALL_OPERAND_SIZES_ELEM,
],
]
),
]
)
# # No LaTex set for this table.
# table_plaintext_ciphertext_ops_pfail64_ks32 = LatexTable()
#
# # No LaTex set for this table.
# table_plaintext_ciphertext_ops_pfail128_ks32 = LatexTable()
#
# # No LaTex set for this table.
# table_ciphertext_ciphertext_ops_pfail64_ks32 = LatexTable()
#
# # No LaTex set for this table.
# table_ciphertext_ciphertext_ops_pfail128_ks32 = LatexTable()
#
# table_compression_benchmarks = LatexTable()

View File

@@ -0,0 +1,269 @@
import copy
import itertools
import pathlib
import sys
import utils
import config, connector
from benchmark_specs import (
AtomicPattern,
ErrorFailureProbability,
GroupingFactor,
Precision,
)
from formatters.common import (
LatexTable,
)
from formatters.core import CoreFormatter
from formatters.integer import IntegerFormatter
from .tables import (
TABLE_BENCH_MULTIBIT_BY_PRECISION,
TABLE_COMPARISON_OPERATIONS_BOOTSTRAPPING128KS32,
TABLE_COMPARISON_OPERATIONS_PRECISION_PFAIL64,
TABLE_COMPARISON_OPERATIONS_PRECISION_PFAIL128,
TABLE_PBS_BENCH,
)
class Default(dict):
def __missing__(self, key):
return f"{{{key}}}"
class ParametersFilterCase:
def __init__(
self,
param_name_pattern: str,
pfails: list[ErrorFailureProbability] = None,
grouping_factors: list[GroupingFactor] = None,
precisions: list[Precision] = None,
atomic_patterns: list[AtomicPattern] = None,
associated_tables: list[LatexTable] = None,
):
self.param_name_pattern = param_name_pattern
self.pfails = pfails or []
self.grouping_factors = grouping_factors or []
self.precisions = precisions or []
self.atomic_patterns = atomic_patterns or []
self.associated_tables = associated_tables or []
def get_parameter_variants(self):
after_pfails = []
for pfail in self.pfails:
after_pfails.append(
self.param_name_pattern.format_map(Default(pfail=pfail.to_str()))
)
after_grouping_factors = []
if after_pfails:
for name in after_pfails:
for gf in self.grouping_factors:
after_grouping_factors.append(name.format_map(Default(gf=gf)))
else:
for gf in self.grouping_factors:
after_grouping_factors.append(
self.grouping_factors.format_map(Default(gf=gf))
)
last_populated = after_grouping_factors or after_pfails
after_msg_carry_sizes = []
if last_populated:
for name in last_populated:
for p in self.precisions:
after_msg_carry_sizes.append(
name.format_map(Default(msg=p.message(), carry=p.carry()))
)
else:
for p in self.precisions:
after_msg_carry_sizes.append(
self.param_name_pattern.format_map(
Default(msg=p.message(), carry=p.carry())
)
)
last_populated = last_populated or after_msg_carry_sizes
after_atomic_patterns = []
if last_populated:
for name in last_populated:
for a in self.atomic_patterns:
after_atomic_patterns.append(
name.format_map(Default(atomic_pattern=a))
)
else:
for a in self.atomic_patterns:
after_atomic_patterns.append(
self.param_name_pattern.format_map(Default(atomic_pattern=a))
)
return last_populated or after_atomic_patterns or [self.param_name_pattern]
CORE_CRYPTO_PARAM_CASES = [
ParametersFilterCase(
"%PARAM_MESSAGE_{msg}_CARRY_{carry}_KS_PBS_GAUSSIAN_{pfail}",
pfails=[
ErrorFailureProbability.TWO_MINUS_64,
ErrorFailureProbability.TWO_MINUS_128,
],
precisions=[Precision.M1C1, Precision.M2C2, Precision.M3C3, Precision.M4C4],
associated_tables=[TABLE_PBS_BENCH],
),
ParametersFilterCase(
"%PARAM_MULTI_BIT_GROUP_{gf}_MESSAGE_{msg}_CARRY_{carry}_KS_PBS_GAUSSIAN_{pfail}",
pfails=[
ErrorFailureProbability.TWO_MINUS_64,
ErrorFailureProbability.TWO_MINUS_128,
],
grouping_factors=[
GroupingFactor.Two,
GroupingFactor.Three,
GroupingFactor.Four,
],
precisions=[Precision.M1C1, Precision.M2C2, Precision.M3C3, Precision.M4C4],
associated_tables=[TABLE_BENCH_MULTIBIT_BY_PRECISION],
),
]
INTEGER_PARAM_CASES = [
# ParametersFilterCase( # TODO Table 5, 6
# "%PARAM_MESSAGE_{msg}_CARRY_{carry}_KS_PBS_GAUSSIAN_{pfail}", # 1_1, 2_2, 4_4 (pfail: 2m64, 2m128)
# pfails=[
# ErrorFailureProbability.TWO_MINUS_64,
# ErrorFailureProbability.TWO_MINUS_128,
# ],
# precisions=[Precision.M1C1, Precision.M2C2, Precision.M4C4],
# associated_tables=[
# TABLE_COMPARISON_OPERATIONS_PRECISION_PFAIL64,
# TABLE_COMPARISON_OPERATIONS_PRECISION_PFAIL128,
# ],
# ),
ParametersFilterCase( # TODO 8
"%PARAM_MESSAGE_2_CARRY_2_{atomic_pattern}_GAUSSIAN_2M128",
atomic_patterns=[
AtomicPattern.KSPBS,
AtomicPattern.KS32PBS,
],
associated_tables=[TABLE_COMPARISON_OPERATIONS_BOOTSTRAPPING128KS32],
),
ParametersFilterCase( # TODO Table 9, 10, 11, 12, NEED TO DEAL WITH OPERANDS (ct and plaintext)
"%PARAM_MESSAGE_2_CARRY_2_KS32_PBS_GAUSSIAN_{pfail}",
pfails=[
ErrorFailureProbability.TWO_MINUS_64,
ErrorFailureProbability.TWO_MINUS_128,
],
associated_tables=BLAH,
),
# ParametersFilterCase( # TODO Table 7
# "%PARAM_MULTI_BIT_GROUP_{gf}_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_{pfail}",
# pfails=[
# ErrorFailureProbability.TWO_MINUS_64,
# ErrorFailureProbability.TWO_MINUS_128,
# ],
# grouping_factors=[
# GroupingFactor.Two,
# GroupingFactor.Three,
# GroupingFactor.Four,
# ],
# ),
# ParametersFilterCase( # TODO Table 13
# "%COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_{pfail}",
# pfails=[
# ErrorFailureProbability.TWO_MINUS_64,
# ErrorFailureProbability.TWO_MINUS_128,
# ],
# ),
]
def _generate_latex_tables(
conn: connector.PostgreConnector,
user_config: config.UserConfig,
result_dir: pathlib.Path,
):
conversion_func = utils.convert_latency_value_to_readable_text
case_config = copy.deepcopy(user_config)
case_config.backend = config.Backend.CPU
case_config.pbs_kind = config.PBSKind.Any
# for case in CORE_CRYPTO_PARAM_CASES:
# case_config.layer = config.Layer.CoreCrypto
# param_patterns = case.get_parameter_variants()
# res = conn.fetch_benchmark_data(case_config, param_name_patterns=param_patterns)
#
# generic_formatter = CoreFormatter(
# case_config.layer,
# case_config.backend,
# case_config.pbs_kind,
# case_config.grouping_factor,
# )
# formatted_results = generic_formatter.format_data_with_available_sizes(
# res,
# conversion_func,
# )
#
# for k, v in formatted_results.items(): # DEBUG
# print(k)
# for sub_k, sub_v in v.items():
# print(f"\t{sub_k}: {sub_v}")
# print("")
#
# print("--------------------------------------------------")
# print("--------------------------------------------------")
# print("--------------------------------------------------")
#
# for table in case.associated_tables:
# formatted_table = table.format_table(
# formatted_results
# )
# print(formatted_table) # DEBUG
for case in INTEGER_PARAM_CASES:
case_config.layer = config.Layer.Integer
param_patterns = case.get_parameter_variants()
res = conn.fetch_benchmark_data(case_config, param_name_patterns=param_patterns)
generic_formatter = IntegerFormatter(
case_config.layer,
case_config.backend,
case_config.pbs_kind,
case_config.grouping_factor,
)
formatted_results = generic_formatter.format_data_with_available_sizes(
res,
conversion_func,
)
# FIXME il faut qu'on puisse avoir accès aux définitions de paramètres après formattage afin de pouvoir filtrer
# ensuite sans quoi on ne peut pas faire de parameters matching
for k, v in formatted_results.items(): # DEBUG
print(k)
for sub_k, sub_v in v.items():
print(f"\t{sub_k}: {sub_v}")
print("")
print("--------------------------------------------------")
print("--------------------------------------------------")
print("--------------------------------------------------")
for table in case.associated_tables:
formatted_table = table.format_table(formatted_results)
# TODO écrire chaque table dans un fichier en récupérant le __name__ de la table, le lower() et strip() le préfixe "table_"
print(formatted_table) # DEBUG
# TODO prendre la valeur minimum dans un groupe d'opération (ex: min/max, gt/ge/lt/le, ...)
def perform_latex_generation(
conn: connector.PostgreConnector, user_config: config.UserConfig
):
dir_path = user_config.output_file
try:
dir_path.mkdir(parents=True, exist_ok=True)
except Exception as e: ## TODO find the exact exception that can be raised here
pass
# TODO passer un dossier en user_config et enregistrer les tables latex dans ce dossier un fichier = une table
_generate_latex_tables(conn, user_config, dir_path)