mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-07 22:04:10 -05:00
chore(ci): fix data extraction on backends comparison for docs
This commit is contained in:
@@ -4,6 +4,8 @@ import config
|
|||||||
import connector
|
import connector
|
||||||
from benchmark_specs import Backend, Layer, OperandType, PBSKind, RustType
|
from benchmark_specs import Backend, Layer, OperandType, PBSKind, RustType
|
||||||
from formatters.common import OPERATION_SIZE_COLUMN_HEADER, GenericFormatter
|
from formatters.common import OPERATION_SIZE_COLUMN_HEADER, GenericFormatter
|
||||||
|
from formatters.hlapi import HlApiFormatter
|
||||||
|
from formatters.integer import IntegerFormatter
|
||||||
|
|
||||||
import utils
|
import utils
|
||||||
|
|
||||||
@@ -51,13 +53,13 @@ def perform_backends_comparison(
|
|||||||
|
|
||||||
res = conn.fetch_benchmark_data(case_config)
|
res = conn.fetch_benchmark_data(case_config)
|
||||||
|
|
||||||
generic_formatter = GenericFormatter(
|
integer_formatter = IntegerFormatter(
|
||||||
case_config.layer,
|
case_config.layer,
|
||||||
case_config.backend,
|
case_config.backend,
|
||||||
case_config.pbs_kind,
|
case_config.pbs_kind,
|
||||||
case_config.grouping_factor,
|
case_config.grouping_factor,
|
||||||
)
|
)
|
||||||
formatted_results = generic_formatter.format_data(
|
formatted_results = integer_formatter.format_data(
|
||||||
res,
|
res,
|
||||||
conversion_func,
|
conversion_func,
|
||||||
)
|
)
|
||||||
@@ -67,13 +69,13 @@ def perform_backends_comparison(
|
|||||||
if backend == Backend.HPU:
|
if backend == Backend.HPU:
|
||||||
case_config.layer = Layer.HLApi
|
case_config.layer = Layer.HLApi
|
||||||
hlapi_res = conn.fetch_benchmark_data(case_config)
|
hlapi_res = conn.fetch_benchmark_data(case_config)
|
||||||
hlapi_generic_formatter = GenericFormatter(
|
hlapi_formatter = HlApiFormatter(
|
||||||
case_config.layer,
|
case_config.layer,
|
||||||
case_config.backend,
|
case_config.backend,
|
||||||
case_config.pbs_kind,
|
case_config.pbs_kind,
|
||||||
case_config.grouping_factor,
|
case_config.grouping_factor,
|
||||||
)
|
)
|
||||||
hlapi_formatted_results = hlapi_generic_formatter.format_data(
|
hlapi_formatted_results = hlapi_formatter.format_data(
|
||||||
hlapi_res,
|
hlapi_res,
|
||||||
conversion_func,
|
conversion_func,
|
||||||
)
|
)
|
||||||
@@ -84,7 +86,7 @@ def perform_backends_comparison(
|
|||||||
if k in integer_sizes_fetched
|
if k in integer_sizes_fetched
|
||||||
}
|
}
|
||||||
|
|
||||||
generic_arrays = generic_formatter.generate_array(
|
generic_arrays = integer_formatter.generate_array(
|
||||||
formatted_results,
|
formatted_results,
|
||||||
OperandType.CipherText,
|
OperandType.CipherText,
|
||||||
included_types=[
|
included_types=[
|
||||||
|
|||||||
@@ -187,7 +187,9 @@ class GenericFormatter:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _format_data(*args, **kwargs):
|
def _format_data(*args, **kwargs):
|
||||||
# Must be implemented by subclasses
|
# Must be implemented by subclasses
|
||||||
raise NotImplementedError
|
raise NotImplementedError(
|
||||||
|
f"format_data() not implemented for this formatter: '{__class__.__name__}'"
|
||||||
|
)
|
||||||
|
|
||||||
def generate_array(
|
def generate_array(
|
||||||
self,
|
self,
|
||||||
@@ -231,7 +233,9 @@ class GenericFormatter:
|
|||||||
|
|
||||||
def _generate_arrays(self, *args, **kwargs):
|
def _generate_arrays(self, *args, **kwargs):
|
||||||
# Must be implemented by subclasses
|
# Must be implemented by subclasses
|
||||||
raise NotImplementedError
|
raise NotImplementedError(
|
||||||
|
f"generate_arrays() not implemented for this formatter: '{__class__.__name__}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CSVFormatter(GenericFormatter):
|
class CSVFormatter(GenericFormatter):
|
||||||
|
|||||||
Reference in New Issue
Block a user