mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-05 04:44:41 -05:00
chore(ci): fix data extraction on backends comparison for docs
This commit is contained in:
@@ -4,6 +4,8 @@ import config
|
||||
import connector
|
||||
from benchmark_specs import Backend, Layer, OperandType, PBSKind, RustType
|
||||
from formatters.common import OPERATION_SIZE_COLUMN_HEADER, GenericFormatter
|
||||
from formatters.hlapi import HlApiFormatter
|
||||
from formatters.integer import IntegerFormatter
|
||||
|
||||
import utils
|
||||
|
||||
@@ -51,13 +53,13 @@ def perform_backends_comparison(
|
||||
|
||||
res = conn.fetch_benchmark_data(case_config)
|
||||
|
||||
generic_formatter = GenericFormatter(
|
||||
integer_formatter = IntegerFormatter(
|
||||
case_config.layer,
|
||||
case_config.backend,
|
||||
case_config.pbs_kind,
|
||||
case_config.grouping_factor,
|
||||
)
|
||||
formatted_results = generic_formatter.format_data(
|
||||
formatted_results = integer_formatter.format_data(
|
||||
res,
|
||||
conversion_func,
|
||||
)
|
||||
@@ -67,13 +69,13 @@ def perform_backends_comparison(
|
||||
if backend == Backend.HPU:
|
||||
case_config.layer = Layer.HLApi
|
||||
hlapi_res = conn.fetch_benchmark_data(case_config)
|
||||
hlapi_generic_formatter = GenericFormatter(
|
||||
hlapi_formatter = HlApiFormatter(
|
||||
case_config.layer,
|
||||
case_config.backend,
|
||||
case_config.pbs_kind,
|
||||
case_config.grouping_factor,
|
||||
)
|
||||
hlapi_formatted_results = hlapi_generic_formatter.format_data(
|
||||
hlapi_formatted_results = hlapi_formatter.format_data(
|
||||
hlapi_res,
|
||||
conversion_func,
|
||||
)
|
||||
@@ -84,7 +86,7 @@ def perform_backends_comparison(
|
||||
if k in integer_sizes_fetched
|
||||
}
|
||||
|
||||
generic_arrays = generic_formatter.generate_array(
|
||||
generic_arrays = integer_formatter.generate_array(
|
||||
formatted_results,
|
||||
OperandType.CipherText,
|
||||
included_types=[
|
||||
|
||||
@@ -187,7 +187,9 @@ class GenericFormatter:
|
||||
@staticmethod
|
||||
def _format_data(*args, **kwargs):
|
||||
# Must be implemented by subclasses
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError(
|
||||
f"format_data() not implemented for this formatter: '{__class__.__name__}'"
|
||||
)
|
||||
|
||||
def generate_array(
|
||||
self,
|
||||
@@ -231,7 +233,9 @@ class GenericFormatter:
|
||||
|
||||
def _generate_arrays(self, *args, **kwargs):
|
||||
# Must be implemented by subclasses
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError(
|
||||
f"generate_arrays() not implemented for this formatter: '{__class__.__name__}'"
|
||||
)
|
||||
|
||||
|
||||
class CSVFormatter(GenericFormatter):
|
||||
|
||||
Reference in New Issue
Block a user