diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_feedback.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_feedback.py index 22990bf3c..499c9a79f 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_feedback.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_feedback.py @@ -21,7 +21,22 @@ from .wrapper import WrapperCpp # matches (@tag, separator( | ), filename) -REGEX_LOCATION = r"loc\(\"(@[\w\.]+)?( \| )?(.+)\"" +REGEX_LOCATION = re.compile(r"loc\(\"(@[\w\.]+)?( \| )?(.+)\"") + + +def tag_from_location(location): + """ + Extract tag of the operation from its location. + """ + + match = REGEX_LOCATION.match(location) + if match is not None: + tag, _, _ = match.groups() + # remove the @ + tag = tag[1:] if tag else "" + else: + tag = "" + return tag class CompilationFeedback(WrapperCpp): @@ -135,9 +150,7 @@ class CompilationFeedback(WrapperCpp): if statistic.operation not in operations: continue - tag, _, _ = re.match(REGEX_LOCATION, statistic.location).groups() - # remove the @ - tag = tag[1:] if tag else "" + tag = tag_from_location(statistic.location) tag_components = tag.split(".") for i in range(1, len(tag_components) + 1): @@ -182,9 +195,7 @@ class CompilationFeedback(WrapperCpp): if statistic.operation not in operations: continue - tag, _, _ = re.match(REGEX_LOCATION, statistic.location).groups() - # remove the @ - tag = tag[1:] if tag else "" + tag = tag_from_location(statistic.location) tag_components = tag.split(".") for i in range(1, len(tag_components) + 1): diff --git a/compilers/concrete-compiler/compiler/tests/python/test_statistics.py b/compilers/concrete-compiler/compiler/tests/python/test_statistics.py new file mode 100644 index 000000000..faf69d46f --- /dev/null +++ b/compilers/concrete-compiler/compiler/tests/python/test_statistics.py @@ -0,0 +1,74 @@ +import numpy as np +import pytest +import shutil +import tempfile + +from concrete.compiler import ( + ClientSupport, + EvaluationKeys, + KeySet, + LibrarySupport, + PublicArguments, + PublicResult, +) +from mlir._mlir_libs._concretelang._compiler import KeyType, PrimitiveOperation + + +def test_statistics(): + mlir = """ + +module { + func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<6> { + %cst = arith.constant dense<[0, 1, 4, 9, 16, 25, 36, 49]> : tensor<8xi64> + %0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<3>, tensor<8xi64>) -> !FHE.eint<6> + return %0 : !FHE.eint<6> + } +} + + """.strip() + + with tempfile.TemporaryDirectory() as tmpdirname: + support = LibrarySupport.new(str(tmpdirname)) + compilation_result = support.compile(mlir) + + client_parameters = support.load_client_parameters(compilation_result) + compilation_feedback = support.load_compilation_feedback(compilation_result) + + pbs_count = compilation_feedback.count( + operations={ + PrimitiveOperation.PBS, + PrimitiveOperation.WOP_PBS, + } + ) + assert pbs_count == 1 + + pbs_counts_per_parameter = compilation_feedback.count_per_parameter( + operations={ + PrimitiveOperation.PBS, + PrimitiveOperation.WOP_PBS, + }, + key_types={KeyType.BOOTSTRAP}, + client_parameters=client_parameters, + ) + assert len(pbs_counts_per_parameter) == 1 + assert pbs_counts_per_parameter[list(pbs_counts_per_parameter.keys())[0]] == 1 + + pbs_counts_per_tag = compilation_feedback.count_per_tag( + operations={ + PrimitiveOperation.PBS, + PrimitiveOperation.WOP_PBS, + } + ) + assert pbs_counts_per_tag == {} + + pbs_counts_per_tag_per_parameter = ( + compilation_feedback.count_per_tag_per_parameter( + operations={ + PrimitiveOperation.PBS, + PrimitiveOperation.WOP_PBS, + }, + key_types={KeyType.BOOTSTRAP}, + client_parameters=client_parameters, + ) + ) + assert pbs_counts_per_tag_per_parameter == {}