Files
concrete/frontends/concrete-python/benchmarks/primitive.py
2024-07-12 11:51:41 +02:00

135 lines
3.9 KiB
Python

"""
Benchmark primitive operations.
"""
# pylint: disable=import-error,cell-var-from-loop,redefined-outer-name
import random
import py_progress_tracker as progress
from concrete import fhe
targets = []
configuration = fhe.Configuration()
# Table Lookup
for bit_width in range(2, 8 + 1):
targets.append(
{
"id": f"table-lookup :: tlu[eint{bit_width}]",
"name": f"{bit_width}-bit table lookup",
"parameters": {
"function": lambda x: x // 2,
"encryption": {"x": "encrypted"},
"inputset": fhe.inputset(lambda _: random.randint(0, (2**bit_width) - 1)),
"configuration": configuration,
},
}
)
# Encrypted Multiplication
for bit_width in range(2, 8 + 1):
targets.append(
{
"id": f"encrypted-multiplication :: eint{bit_width} * eint{bit_width}",
"name": f"{bit_width}-bit encrypted multiplication",
"parameters": {
"function": lambda x, y: x * y,
"encryption": {"x": "encrypted", "y": "encrypted"},
"inputset": fhe.inputset(
lambda _: random.randint(0, (2**bit_width) - 1),
lambda _: random.randint(0, (2**bit_width) - 1),
),
"configuration": configuration,
},
}
)
@progress.track(targets)
def main(function, encryption, inputset, configuration):
"""
Benchmark a target.
Args:
function:
function to benchmark
encryption:
encryption status of the arguments of the function
inputset:
inputset to use for compiling the function
configuration:
configuration to use for compilation
"""
compiler = fhe.Compiler(function, encryption)
print("Compiling...")
with progress.measure(id="compilation-time-ms", label="Compilation Time (ms)"):
circuit = compiler.compile(inputset, configuration)
progress.measure(
id="complexity",
label="Complexity",
value=circuit.complexity,
)
print("Generating keys...")
with progress.measure(id="key-generation-time-ms", label="Key Generation Time (ms)"):
circuit.keygen(force=True)
progress.measure(
id="evaluation-key-size-mb",
label="Evaluation Key Size (MB)",
value=(len(circuit.keys.evaluation.serialize()) / (1024 * 1024)),
)
# pylint: disable=unused-variable
print("Warming up...")
sample = random.choice(inputset)
encrypted = circuit.encrypt(*sample)
ran = circuit.run(encrypted)
decrypted = circuit.decrypt(ran) # noqa: F841
# pylint: enable=unused-variable
def calculate_input_output_size(input_output):
if isinstance(input_output, tuple):
result = sum(len(value.serialize()) for value in input_output)
else:
result = len(input_output.serialize())
return result / (1024 * 1024)
progress.measure(
id="input-ciphertext-size-mb",
label="Input Ciphertext Size (MB)",
value=calculate_input_output_size(encrypted),
)
progress.measure(
id="output-ciphertext-size-mb",
label="Output Ciphertext Size (MB)",
value=calculate_input_output_size(ran),
)
for i in range(10):
print(f"Running subsample {i + 1} out of 10...")
sample = random.choice(inputset)
with progress.measure(id="encryption-time-ms", label="Encryption Time (ms)"):
encrypted = circuit.encrypt(*sample)
with progress.measure(id="evaluation-time-ms", label="Evaluation Time (ms)"):
ran = circuit.run(encrypted)
with progress.measure(id="decryption-time-ms", label="Decryption Time (ms)"):
output = circuit.decrypt(ran)
progress.measure(
id="accuracy",
label="Accuracy",
value=int(output == function(*sample)),
)