Files
concrete/benchmarks/multi_table_lookup.py

61 lines
1.6 KiB
Python

# bench: Unit Target: Multi Table Lookup
import math
import numpy as np
from common import BENCHMARK_CONFIGURATION
import concrete.numpy as hnp
def main():
    """Benchmark compiling and evaluating a multi-table lookup on encrypted input.

    Builds a grid of lookup tables (``num_rows`` rows, one column per entry
    of ``column_tables``). It then compiles ``multi_table[x]`` for an
    encrypted unsigned tensor of ``input_bits``-bit values. Finally it runs
    the compiled engine on random samples and counts how many results equal
    the same lookup computed on plain numpy arrays.

    The benchmark directive comments below appear to be parsed by the
    benchmark tooling (presumably from the source text), so they must stay
    in place. The accuracy directive refers to the local names ``correct``
    and ``inputs``, so those names must not change.
    """
    input_bits = 3
    # Exclusive upper bound of the values representable with `input_bits`
    # unsigned bits; shared by the tables and every random sample.
    input_range = 2 ** input_bits

    square_table = hnp.LookupTable([i ** 2 for i in range(input_range)])
    sqrt_table = hnp.LookupTable([int(math.sqrt(i)) for i in range(input_range)])

    # One table per input column, repeated for every row. The input shape is
    # derived from this grid, so the MultiLookupTable, the encrypted input
    # declaration and all samples cannot drift apart.
    # NOTE(review): presumably input element [r, c] is looked up in
    # grid[r][c] -- confirm against MultiLookupTable.
    column_tables = [square_table, sqrt_table]
    num_rows = 3
    input_shape = (num_rows, len(column_tables))
    multi_table = hnp.MultiLookupTable(
        [list(column_tables) for _ in range(num_rows)]
    )

    def function_to_compile(x):
        return multi_table[x]

    x = hnp.EncryptedTensor(hnp.UnsignedInteger(input_bits), shape=input_shape)

    # bench: Measure: Compilation Time (ms)
    engine = hnp.compile_numpy_function(
        function_to_compile,
        {"x": x},
        [np.random.randint(0, input_range, size=input_shape) for _ in range(32)],
        compilation_configuration=BENCHMARK_CONFIGURATION,
    )
    # bench: Measure: End

    # Each entry of `inputs` is the positional-argument list for one engine
    # call. `labels` holds the expected outputs, computed by running the same
    # function on the plain (unencrypted) numpy arrays.
    inputs = [
        [np.random.randint(0, input_range, size=input_shape, dtype=np.uint8)]
        for _ in range(50)
    ]
    labels = [function_to_compile(*input_i) for input_i in inputs]

    correct = 0
    for input_i, label_i in zip(inputs, labels):
        # bench: Measure: Evaluation Time (ms)
        result_i = engine.run(*input_i)
        # bench: Measure: End
        if np.array_equal(result_i, label_i):
            correct += 1

    # bench: Measure: Accuracy (%) = (correct / len(inputs)) * 100
    # bench: Alert: Accuracy (%) < 99
if __name__ == "__main__":
    # Execute the benchmark only when this file is run as a script,
    # not when it is imported as a module.
    main()