mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-17 16:11:26 -05:00
fix(compiler-bindings): don't crash on unknown location formats
This commit is contained in:
@@ -21,7 +21,22 @@ from .wrapper import WrapperCpp
|
||||
|
||||
|
||||
# matches (@tag, separator( | ), filename)
|
||||
REGEX_LOCATION = r"loc\(\"(@[\w\.]+)?( \| )?(.+)\""
|
||||
REGEX_LOCATION = re.compile(r"loc\(\"(@[\w\.]+)?( \| )?(.+)\"")
|
||||
|
||||
|
||||
def tag_from_location(location):
|
||||
"""
|
||||
Extract tag of the operation from its location.
|
||||
"""
|
||||
|
||||
match = REGEX_LOCATION.match(location)
|
||||
if match is not None:
|
||||
tag, _, _ = match.groups()
|
||||
# remove the @
|
||||
tag = tag[1:] if tag else ""
|
||||
else:
|
||||
tag = ""
|
||||
return tag
|
||||
|
||||
|
||||
class CompilationFeedback(WrapperCpp):
|
||||
@@ -135,9 +150,7 @@ class CompilationFeedback(WrapperCpp):
|
||||
if statistic.operation not in operations:
|
||||
continue
|
||||
|
||||
tag, _, _ = re.match(REGEX_LOCATION, statistic.location).groups()
|
||||
# remove the @
|
||||
tag = tag[1:] if tag else ""
|
||||
tag = tag_from_location(statistic.location)
|
||||
|
||||
tag_components = tag.split(".")
|
||||
for i in range(1, len(tag_components) + 1):
|
||||
@@ -182,9 +195,7 @@ class CompilationFeedback(WrapperCpp):
|
||||
if statistic.operation not in operations:
|
||||
continue
|
||||
|
||||
tag, _, _ = re.match(REGEX_LOCATION, statistic.location).groups()
|
||||
# remove the @
|
||||
tag = tag[1:] if tag else ""
|
||||
tag = tag_from_location(statistic.location)
|
||||
|
||||
tag_components = tag.split(".")
|
||||
for i in range(1, len(tag_components) + 1):
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
from concrete.compiler import (
|
||||
ClientSupport,
|
||||
EvaluationKeys,
|
||||
KeySet,
|
||||
LibrarySupport,
|
||||
PublicArguments,
|
||||
PublicResult,
|
||||
)
|
||||
from mlir._mlir_libs._concretelang._compiler import KeyType, PrimitiveOperation
|
||||
|
||||
|
||||
def test_statistics():
|
||||
mlir = """
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<6> {
|
||||
%cst = arith.constant dense<[0, 1, 4, 9, 16, 25, 36, 49]> : tensor<8xi64>
|
||||
%0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<3>, tensor<8xi64>) -> !FHE.eint<6>
|
||||
return %0 : !FHE.eint<6>
|
||||
}
|
||||
}
|
||||
|
||||
""".strip()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
support = LibrarySupport.new(str(tmpdirname))
|
||||
compilation_result = support.compile(mlir)
|
||||
|
||||
client_parameters = support.load_client_parameters(compilation_result)
|
||||
compilation_feedback = support.load_compilation_feedback(compilation_result)
|
||||
|
||||
pbs_count = compilation_feedback.count(
|
||||
operations={
|
||||
PrimitiveOperation.PBS,
|
||||
PrimitiveOperation.WOP_PBS,
|
||||
}
|
||||
)
|
||||
assert pbs_count == 1
|
||||
|
||||
pbs_counts_per_parameter = compilation_feedback.count_per_parameter(
|
||||
operations={
|
||||
PrimitiveOperation.PBS,
|
||||
PrimitiveOperation.WOP_PBS,
|
||||
},
|
||||
key_types={KeyType.BOOTSTRAP},
|
||||
client_parameters=client_parameters,
|
||||
)
|
||||
assert len(pbs_counts_per_parameter) == 1
|
||||
assert pbs_counts_per_parameter[list(pbs_counts_per_parameter.keys())[0]] == 1
|
||||
|
||||
pbs_counts_per_tag = compilation_feedback.count_per_tag(
|
||||
operations={
|
||||
PrimitiveOperation.PBS,
|
||||
PrimitiveOperation.WOP_PBS,
|
||||
}
|
||||
)
|
||||
assert pbs_counts_per_tag == {}
|
||||
|
||||
pbs_counts_per_tag_per_parameter = (
|
||||
compilation_feedback.count_per_tag_per_parameter(
|
||||
operations={
|
||||
PrimitiveOperation.PBS,
|
||||
PrimitiveOperation.WOP_PBS,
|
||||
},
|
||||
key_types={KeyType.BOOTSTRAP},
|
||||
client_parameters=client_parameters,
|
||||
)
|
||||
)
|
||||
assert pbs_counts_per_tag_per_parameter == {}
|
||||
Reference in New Issue
Block a user