mirror of
https://github.com/zama-ai/concrete.git
synced 2026-01-22 19:28:05 -05:00
229 lines
7.0 KiB
Python
229 lines
7.0 KiB
Python
import argparse
|
|
import time
|
|
|
|
import numpy as np
|
|
|
|
from concrete import fhe
|
|
|
|
|
|
# Hamming weight computation
|
|
def hw(x):
    """Look up the Hamming weight (count of set bits) of ``x``.

    The lookup goes through a 256-entry ``fhe.LookupTable``, i.e. one entry
    per possible 8-bit value, which is then indexed with ``x``.
    """
    # Number of set bits of every 8b entry
    popcounts = [bin(entry).count("1") for entry in range(2**8)]

    # fmt: off
    expected = [
        0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3,
        4, 3, 4, 4, 5, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 3, 4,
        4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, 2,
        3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5,
        4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, 1, 2, 2, 3, 2, 3, 3,
        4, 2, 3, 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 2, 3,
        3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5,
        6, 6, 7, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6,
        4, 5, 5, 6, 5, 6, 6, 7, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, 4, 5, 5,
        6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8,
    ]
    # fmt: on

    # Sanity check: the generated table must match the hard-coded reference
    assert np.array_equal(popcounts, expected)

    return fhe.LookupTable(popcounts)[x]
|
|
|
|
|
|
# Reference function for tests
|
|
def dist_in_clear(x, y):
    """Return the sum of the Hamming weights of ``x ^ y``.

    Used by the benchmark as the expected value for the decrypted results.
    """
    differing = x ^ y
    weights = hw(differing)
    return np.sum(weights)
|
|
|
|
|
|
# For all bitsize_w
|
|
def dist_in_fhe_directly_from_cp(x, y):
    """Return the summed Hamming weight of ``x ^ y``, written as a plain XOR.

    The XOR is expressed directly with the ``^`` operator and the per-cell
    weight is taken with ``hw``.
    """
    xored = x ^ y
    per_cell_weight = hw(xored)
    return np.sum(per_cell_weight)
|
|
|
|
|
|
# For bitsize_w == 1 only
|
|
def dist_in_fhe_with_bits_1b(x, y):
    """Return the sum of bit 0 of ``x + y`` over all cells.

    Bit 0 is extracted with ``fhe.bits(...)[0]``; the caller only uses this
    variant for 1-bit cells.
    """
    bit_zero = fhe.bits(x + y)[0]
    return np.sum(bit_zero)
|
|
|
|
|
|
# For all bitsize_w
|
|
def dist_in_fhe_with_xor_internal(x, y, bitsize_w):
    """Return the summed Hamming weight of ``x ^ y`` using one packed table.

    Both operands are packed into a single index ``x + 2**bitsize_w * y`` and
    the weight of the XOR of the two packed digits is read from a lookup
    table holding ``(2**bitsize_w)**2`` entries.
    """
    base = 2**bitsize_w

    # Entry ``packed`` holds hw(low ^ high) where low = packed % base and
    # high = packed // base
    weights = []
    for packed in range(base**2):
        high, low = divmod(packed, base)
        weights.append(hw(low ^ high))
    xor_weight_table = fhe.LookupTable(weights)

    packed_cells = x + base * y
    return np.sum(xor_weight_table[packed_cells])
|
|
|
|
|
|
# For all bitsize_w
|
|
def dist_in_fhe_with_multivariate_internal(x, y):
    """Return the summed Hamming weight of ``x ^ y`` via ``fhe.multivariate``.

    The two-input function ``hw(a ^ b)`` is wrapped with ``fhe.multivariate``
    and applied to ``x`` and ``y``; the per-cell results are then summed.
    """
    xor_weight = fhe.multivariate(lambda lhs, rhs: hw(lhs ^ rhs))
    return np.sum(xor_weight(x, y))
|
|
|
|
|
|
# Manage user args
|
|
def manage_args(argv=None):
    """Parse and validate the command-line options of the benchmark.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse reads the process command line.

    Returns:
        The ``argparse.Namespace`` with the attributes ``nb_bits`` (int),
        ``shape`` (list of int or ``None``), ``show_mlir`` (bool) and
        ``repeat`` (int).

    Raises:
        SystemExit: Through ``parser.error`` (exit status 2) when
            ``--nb_bits`` or ``--repeat`` is smaller than 1, or on any
            argparse parsing error.
    """
    parser = argparse.ArgumentParser(description="Hamming weight (aka XOR) distance in Concrete.")
    parser.add_argument(
        "--nb_bits",
        type=int,
        default=120,
        help="Number of bits (better to be a multiple of 12 to test all bitwidths)",
    )
    parser.add_argument(
        "--shape",
        type=int,
        nargs="+",
        default=None,
        help="How to shape the bits. It has almost no importance for speed",
    )
    parser.add_argument(
        "--show_mlir",
        action="store_true",
        help="Show the MLIR",
    )
    parser.add_argument(
        "--repeat",
        type=int,
        default=5,
        help="Repeat x times",
    )
    args = parser.parse_args(argv)

    # The average run time is computed as total / repeat, so a repeat count
    # below 1 would divide by zero (or give a meaningless negative average).
    if args.repeat < 1:
        parser.error("--repeat must be at least 1")
    if args.nb_bits < 1:
        parser.error("--nb_bits must be at least 1")

    return args
|
|
|
|
|
|
def _make_xor_table_distance(bitsize_w):
    """Return a two-argument distance function bound to ``bitsize_w``."""
    # The compiled function must take exactly ``x`` and ``y``; building the
    # closure in its own scope binds one ``bitsize_w`` per returned function.
    return lambda x, y: dist_in_fhe_with_xor_internal(x, y, bitsize_w)


def _random_pair(bitsize_w, shape):
    """Draw two random integer arrays of ``shape`` in ``[0, 2**bitsize_w)``."""
    return (
        np.random.randint(2**bitsize_w, size=shape),
        np.random.randint(2**bitsize_w, size=shape),
    )


def main():
    """Benchmark every distance algorithm for cell widths of 1 to 4 bits.

    For each (bit width, algorithm) pair the function compiles the circuit on
    a random inputset, generates keys, then runs ``--repeat`` encrypted
    evaluations, checking each decrypted result against ``dist_in_clear``.
    Average run times are printed per configuration and, at the end, sorted
    from fastest to slowest.

    Configurations whose cell width does not divide ``--nb_bits``, or whose
    ``--shape`` does not hold exactly ``--nb_bits`` bits, are reported and
    skipped.

    Raises:
        SystemExit: If ``--repeat`` is smaller than 1.
        AssertionError: If a decrypted result differs from the clear one.
    """
    print()

    # Options by the user
    args = manage_args()

    nb_bits = args.nb_bits
    repeat = args.repeat

    # The average below divides by ``repeat``
    if repeat < 1:
        raise SystemExit("--repeat must be at least 1")

    execution_times = {}

    for bitsize_w in [1, 2, 3, 4]:
        # Insertion order is the order in which the algorithms are benchmarked
        distance_functions = {
            "dist_in_fhe_with_bits_1b": dist_in_fhe_with_bits_1b,
            "dist_in_fhe_with_xor_tables": _make_xor_table_distance(bitsize_w),
            "dist_in_fhe_with_multivariate_tables": dist_in_fhe_with_multivariate_internal,
            "dist_in_fhe_directly_from_cp": dist_in_fhe_directly_from_cp,
        }

        shape = (1, nb_bits // bitsize_w) if args.shape is None else tuple(args.shape)

        for algo, dist_function in distance_functions.items():
            if algo == "dist_in_fhe_with_bits_1b" and bitsize_w != 1:
                # Only work for 1b
                continue

            # Checks
            if nb_bits % bitsize_w != 0:
                print(
                    f"Number of bits is not a multiple of w, can't test this "
                    f"configuration {algo} {bitsize_w}"
                )
                continue

            # A user-given shape can only match one cell width: skip the
            # other widths instead of aborting the whole benchmark.
            if np.prod(shape) * bitsize_w != nb_bits:
                print(
                    f"Your (shape, w) does not correspond to number of bits, can't test this "
                    f"configuration {algo} {bitsize_w}"
                )
                continue

            # Info
            print(
                f"Computing XOR distance on {nb_bits} bits using algorithm {algo}, using vectors "
                f"of shape {shape} of {bitsize_w}b cells"
            )

            # Compile the circuit
            inputset = [_random_pair(bitsize_w, shape) for _ in range(100)]

            compiler = fhe.Compiler(dist_function, {"x": "encrypted", "y": "encrypted"})
            circuit = compiler.compile(
                inputset,
                show_mlir=args.show_mlir,
                bitwise_strategy_preference=fhe.BitwiseStrategy.ONE_TLU_PROMOTED,
                multivariate_strategy_preference=fhe.MultivariateStrategy.PROMOTED,
            )

            # Then generate the keys
            circuit.keygen()

            total_time = 0

            # Then use
            for _ in range(repeat):
                # Take a random input pair
                x, y = _random_pair(bitsize_w, shape)

                # Encrypt
                encrypted_input = circuit.encrypt(x, y)

                # Compute the distance in FHE; only this call is timed, with a
                # monotonic clock so system clock adjustments cannot skew it
                begin_time = time.perf_counter()
                encrypted_result = circuit.run(encrypted_input)
                end_time = time.perf_counter()

                total_time += end_time - begin_time

                # Decrypt
                result = circuit.decrypt(encrypted_result)

                # Check
                expected = dist_in_clear(x, y)
                assert result == expected, f"FHE result {result} differs from clear {expected}"

            average_time = total_time / repeat
            print(
                f"Distance between encrypted vectors done in {average_time:.2f} "
                f"seconds in average"
            )

            execution_times[f"{algo} on {bitsize_w} bits"] = average_time

        print("")

    # Final results
    print("Results from the fastest to the slowest\n")
    sorted_execution_times = sorted(execution_times.items(), key=lambda item: item[1])

    for algo, average_time in sorted_execution_times:
        print(f"{algo:>50s}: {average_time:5.2f} seconds")

    print()
|
|
|
|
|
|
# Run the benchmark only when executed as a script, not when imported
if __name__ == "__main__":
    main()
|