mirror of
https://github.com/zama-ai/concrete.git
synced 2026-01-13 06:48:02 -05:00
135 lines
3.9 KiB
Python
135 lines
3.9 KiB
Python
"""
|
|
Benchmark primitive operations.
|
|
"""
|
|
|
|
# pylint: disable=import-error,cell-var-from-loop,redefined-outer-name
|
|
|
|
import random
|
|
|
|
import py_progress_tracker as progress
|
|
|
|
from concrete import fhe
|
|
|
|
targets = []
|
|
configuration = fhe.Configuration()
|
|
|
|
# Table Lookup
|
|
for bit_width in range(2, 8 + 1):
|
|
targets.append(
|
|
{
|
|
"id": f"table-lookup :: tlu[eint{bit_width}]",
|
|
"name": f"{bit_width}-bit table lookup",
|
|
"parameters": {
|
|
"function": lambda x: x // 2,
|
|
"encryption": {"x": "encrypted"},
|
|
"inputset": fhe.inputset(lambda _: random.randint(0, (2**bit_width) - 1)),
|
|
"configuration": configuration,
|
|
},
|
|
}
|
|
)
|
|
|
|
# Encrypted Multiplication
|
|
for bit_width in range(2, 8 + 1):
|
|
targets.append(
|
|
{
|
|
"id": f"encrypted-multiplication :: eint{bit_width} * eint{bit_width}",
|
|
"name": f"{bit_width}-bit encrypted multiplication",
|
|
"parameters": {
|
|
"function": lambda x, y: x * y,
|
|
"encryption": {"x": "encrypted", "y": "encrypted"},
|
|
"inputset": fhe.inputset(
|
|
lambda _: random.randint(0, (2**bit_width) - 1),
|
|
lambda _: random.randint(0, (2**bit_width) - 1),
|
|
),
|
|
"configuration": configuration,
|
|
},
|
|
}
|
|
)
|
|
|
|
|
|
@progress.track(targets)
|
|
def main(function, encryption, inputset, configuration):
|
|
"""
|
|
Benchmark a target.
|
|
|
|
Args:
|
|
function:
|
|
function to benchmark
|
|
|
|
encryption:
|
|
encryption status of the arguments of the function
|
|
|
|
inputset:
|
|
inputset to use for compiling the function
|
|
|
|
configuration:
|
|
configuration to use for compilation
|
|
"""
|
|
|
|
compiler = fhe.Compiler(function, encryption)
|
|
|
|
print("Compiling...")
|
|
with progress.measure(id="compilation-time-ms", label="Compilation Time (ms)"):
|
|
circuit = compiler.compile(inputset, configuration)
|
|
|
|
progress.measure(
|
|
id="complexity",
|
|
label="Complexity",
|
|
value=circuit.complexity,
|
|
)
|
|
|
|
print("Generating keys...")
|
|
with progress.measure(id="key-generation-time-ms", label="Key Generation Time (ms)"):
|
|
circuit.keygen(force=True)
|
|
|
|
progress.measure(
|
|
id="evaluation-key-size-mb",
|
|
label="Evaluation Key Size (MB)",
|
|
value=(len(circuit.keys.evaluation.serialize()) / (1024 * 1024)),
|
|
)
|
|
|
|
# pylint: disable=unused-variable
|
|
|
|
print("Warming up...")
|
|
sample = random.choice(inputset)
|
|
encrypted = circuit.encrypt(*sample)
|
|
ran = circuit.run(encrypted)
|
|
decrypted = circuit.decrypt(ran) # noqa: F841
|
|
|
|
# pylint: enable=unused-variable
|
|
|
|
def calculate_input_output_size(input_output):
|
|
if isinstance(input_output, tuple):
|
|
result = sum(len(value.serialize()) for value in input_output)
|
|
else:
|
|
result = len(input_output.serialize())
|
|
return result / (1024 * 1024)
|
|
|
|
progress.measure(
|
|
id="input-ciphertext-size-mb",
|
|
label="Input Ciphertext Size (MB)",
|
|
value=calculate_input_output_size(encrypted),
|
|
)
|
|
progress.measure(
|
|
id="output-ciphertext-size-mb",
|
|
label="Output Ciphertext Size (MB)",
|
|
value=calculate_input_output_size(ran),
|
|
)
|
|
|
|
for i in range(10):
|
|
print(f"Running subsample {i + 1} out of 10...")
|
|
|
|
sample = random.choice(inputset)
|
|
with progress.measure(id="encryption-time-ms", label="Encryption Time (ms)"):
|
|
encrypted = circuit.encrypt(*sample)
|
|
with progress.measure(id="evaluation-time-ms", label="Evaluation Time (ms)"):
|
|
ran = circuit.run(encrypted)
|
|
with progress.measure(id="decryption-time-ms", label="Decryption Time (ms)"):
|
|
output = circuit.decrypt(ran)
|
|
|
|
progress.measure(
|
|
id="accuracy",
|
|
label="Accuracy",
|
|
value=int(output == function(*sample)),
|
|
)
|