diff --git a/ci/data_extractor/src/comparison.py b/ci/data_extractor/src/comparison.py index 7cb0ea459..eadb8aa47 100644 --- a/ci/data_extractor/src/comparison.py +++ b/ci/data_extractor/src/comparison.py @@ -4,6 +4,8 @@ import config import connector from benchmark_specs import Backend, Layer, OperandType, PBSKind, RustType from formatters.common import OPERATION_SIZE_COLUMN_HEADER, GenericFormatter +from formatters.hlapi import HlApiFormatter +from formatters.integer import IntegerFormatter import utils @@ -51,13 +53,13 @@ def perform_backends_comparison( res = conn.fetch_benchmark_data(case_config) - generic_formatter = GenericFormatter( + integer_formatter = IntegerFormatter( case_config.layer, case_config.backend, case_config.pbs_kind, case_config.grouping_factor, ) - formatted_results = generic_formatter.format_data( + formatted_results = integer_formatter.format_data( res, conversion_func, ) @@ -67,13 +69,13 @@ def perform_backends_comparison( if backend == Backend.HPU: case_config.layer = Layer.HLApi hlapi_res = conn.fetch_benchmark_data(case_config) - hlapi_generic_formatter = GenericFormatter( + hlapi_formatter = HlApiFormatter( case_config.layer, case_config.backend, case_config.pbs_kind, case_config.grouping_factor, ) - hlapi_formatted_results = hlapi_generic_formatter.format_data( + hlapi_formatted_results = hlapi_formatter.format_data( hlapi_res, conversion_func, ) @@ -84,7 +86,7 @@ def perform_backends_comparison( if k in integer_sizes_fetched } - generic_arrays = generic_formatter.generate_array( + generic_arrays = integer_formatter.generate_array( formatted_results, OperandType.CipherText, included_types=[ diff --git a/ci/data_extractor/src/formatters/common.py b/ci/data_extractor/src/formatters/common.py index 7858be615..a87ccebe2 100644 --- a/ci/data_extractor/src/formatters/common.py +++ b/ci/data_extractor/src/formatters/common.py @@ -187,7 +187,9 @@ class GenericFormatter: @staticmethod def _format_data(*args, **kwargs): # Must be implemented by subclasses - raise NotImplementedError + raise NotImplementedError( + f"format_data() not implemented for this formatter: '{__class__.__name__}'" + ) def generate_array( self, @@ -231,7 +233,9 @@ class GenericFormatter: def _generate_arrays(self, *args, **kwargs): # Must be implemented by subclasses - raise NotImplementedError + raise NotImplementedError( + f"generate_arrays() not implemented for this formatter: '{__class__.__name__}'" + ) class CSVFormatter(GenericFormatter):