Files
concrete/frontends/concrete-python/examples/xor_distance/hamming_distance.py
2024-04-18 19:14:38 +02:00

229 lines
7.0 KiB
Python

import argparse
import time
import numpy as np
from concrete import fhe
# Hamming weight computation
def hw(x):
# Hamming Weight table for 8b entries
hw_table_ref = [np.binary_repr(x).count("1") for x in range(2**8)]
# fmt: off
assert np.array_equal(hw_table_ref, [
0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3,
4, 3, 4, 4, 5, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 3, 4,
4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, 2,
3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5,
4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, 1, 2, 2, 3, 2, 3, 3,
4, 2, 3, 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 2, 3,
3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5,
6, 6, 7, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6,
4, 5, 5, 6, 5, 6, 6, 7, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, 4, 5, 5,
6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8])
# fmt: on
hw_table = fhe.LookupTable(hw_table_ref)
return hw_table[x]
# Reference function for tests
def dist_in_clear(x, y):
return np.sum(hw(x ^ y))
# For all bitsize_w
def dist_in_fhe_directly_from_cp(x, y):
return np.sum(hw(x ^ y))
# For bitsize_w == 1 only
def dist_in_fhe_with_bits_1b(x, y):
z = x + y
zx = fhe.bits(z)[0]
return np.sum(zx)
# For all bitsize_w
def dist_in_fhe_with_xor_internal(x, y, bitsize_w):
power = 2**bitsize_w
table = fhe.LookupTable([hw((i % power) ^ (i // power)) for i in range(power**2)])
z = x + power * y
zx = table[z]
return np.sum(zx)
# For all bitsize_w
def dist_in_fhe_with_multivariate_internal(x, y):
zx = fhe.multivariate(lambda x, y: hw(x ^ y))(x, y)
return np.sum(zx)
# Manage user args
def manage_args():
parser = argparse.ArgumentParser(description="Hamming weight (aka XOR) distance in Concrete.")
parser.add_argument(
"--nb_bits",
dest="nb_bits",
action="store",
type=int,
default=120,
help="Number of bits (better to be a multiple of 12 to test all bitwidths)",
)
parser.add_argument(
"--shape",
dest="shape",
action="store",
type=int,
nargs="+",
default=None,
help="How to shape the bits. It has almost no importance for speed",
)
parser.add_argument(
"--show_mlir",
dest="show_mlir",
action="store_true",
help="Show the MLIR",
)
parser.add_argument(
"--repeat",
dest="repeat",
action="store",
type=int,
default=5,
help="Repeat x times",
)
args = parser.parse_args()
return args
def main():
print()
# Options by the user
args = manage_args()
nb_bits = args.nb_bits
execution_times = {}
for bitsize_w in [1, 2, 3, 4]:
for algo in [
"dist_in_fhe_with_bits_1b",
"dist_in_fhe_with_xor_tables",
"dist_in_fhe_with_multivariate_tables",
"dist_in_fhe_directly_from_cp",
]:
if algo == "dist_in_fhe_with_bits_1b":
dist_function = dist_in_fhe_with_bits_1b
elif algo == "dist_in_fhe_with_xor_tables":
if bitsize_w == 1:
dist_function = lambda x, y: dist_in_fhe_with_xor_internal(x, y, 1)
elif bitsize_w == 2:
dist_function = lambda x, y: dist_in_fhe_with_xor_internal(x, y, 2)
elif bitsize_w == 3:
dist_function = lambda x, y: dist_in_fhe_with_xor_internal(x, y, 3)
elif bitsize_w == 4:
dist_function = lambda x, y: dist_in_fhe_with_xor_internal(x, y, 4)
elif algo == "dist_in_fhe_with_multivariate_tables":
dist_function = dist_in_fhe_with_multivariate_internal
else:
assert algo == "dist_in_fhe_directly_from_cp"
dist_function = dist_in_fhe_directly_from_cp
if algo == "dist_in_fhe_with_bits_1b" and bitsize_w != 1:
# Only work for 1b
continue
shape = (1, nb_bits // bitsize_w) if args.shape is None else tuple(args.shape)
# Checks
if nb_bits % bitsize_w != 0:
print(
f"Number of bits is not a multiple of w, can't test this "
f"configuration {algo} {bitsize_w}"
)
continue
assert (
np.prod(shape) * bitsize_w == nb_bits
), "Your (shape, w) does not correspond to number of bits"
# Info
print(
f"Computing XOR distance on {nb_bits} bits using algorithm {algo}, using vectors "
f"of shape {shape} of {bitsize_w}b cells"
)
# Compile the circuit
inputset = [
(
np.random.randint(2**bitsize_w, size=shape),
np.random.randint(2**bitsize_w, size=shape),
)
for _ in range(100)
]
compiler = fhe.Compiler(dist_function, {"x": "encrypted", "y": "encrypted"})
circuit = compiler.compile(
inputset,
show_mlir=args.show_mlir,
bitwise_strategy_preference=fhe.BitwiseStrategy.ONE_TLU_PROMOTED,
multivariate_strategy_preference=fhe.MultivariateStrategy.PROMOTED,
)
# Then generate the keys
circuit.keygen()
total_time = 0
# Then use
for _i in range(args.repeat):
# Take a random input pair
x, y = (
np.random.randint(2**bitsize_w, size=shape),
np.random.randint(2**bitsize_w, size=shape),
)
# Encrypt
encrypted_input = circuit.encrypt(x, y)
# Compute the distance in FHE
begin_time = time.time()
encrypted_result = circuit.run(encrypted_input)
end_time = time.time()
total_time += end_time - begin_time
# Decrypt
result = circuit.decrypt(encrypted_result)
# Check
assert result == dist_in_clear(x, y)
average_time = total_time / args.repeat
print(
f"Distance between encrypted vectors done in {average_time:.2f} "
f"seconds in average"
)
execution_times[f"{algo} on {bitsize_w} bits"] = average_time
print("")
# Final results
print("Results from the fastest to the slowest\n")
sorted_execution_times = sorted(execution_times.items(), key=lambda x: x[1])
for algo, average_time in sorted_execution_times:
print(f"{algo:>50s}: {average_time:5.2f} seconds")
print()
if __name__ == "__main__":
main()