refactor(tests): Use f-strings in linalg_apply_lookup_table generator

This commit is contained in:
Quentin Bourgerie
2022-11-29 11:41:37 +01:00
parent 6eb4cec706
commit 3efa8eb2a9

View File

@@ -5,39 +5,35 @@ import argparse
def generate(args):
print("# /!\ DO NOT EDIT MANUALLY THIS FILE MANUALLY")
print("# /!\ THIS FILE HAS BEEN GENERATED THANKS THE end_to_end_levelled_gen.py scripts")
print("# /!\ THIS FILE HAS BEEN GENERATED")
np.random.seed(0)
for n_ct in args.n_ct:
for p in range(args.min_bitwidth, args.max_bitwidth+1):
max_value = (2 ** p) - 1
random_lut = np.random.randint(max_value+1, size=2**p)
# identity_apply_lookup_table
print(
"description: apply_lookup_table_{0}bits_{1}ct".format(p, n_ct))
print(f"description: apply_lookup_table_{p}bits_{n_ct}ct")
print("program: |")
print(
" func.func @main(%0: tensor<{1}x!FHE.eint<{0}>>) -> tensor<{1}x!FHE.eint<{0}>> {{".format(p, n_ct))
print(" %tlu = arith.constant dense<[{0}]> : tensor<{1}xi64>".format(
','.join(map(str, random_lut)), 2**p))
f" func.func @main(%0: tensor<{n_ct}x!FHE.eint<{p}>>) -> tensor<{n_ct}x!FHE.eint<{p}>> {{")
print(f" %tlu = arith.constant dense<[{','.join(map(str, random_lut))}]> : tensor<{2**p}xi64>")
for i in range(0, args.n_lut):
print(
" %{4} = \"FHELinalg.apply_lookup_table\"(%{3}, %tlu): (tensor<{2}x!FHE.eint<{0}>>, tensor<{1}xi64>) -> (tensor<{2}x!FHE.eint<{0}>>)".format(p, 2**p, n_ct, i, i+1))
print(" return %{2}: tensor<{1}x!FHE.eint<{0}>>".format(
p, n_ct, args.n_lut))
print(f" %{i+1} = \"FHELinalg.apply_lookup_table\"(%{i}, %tlu):")
print(f" (tensor<{n_ct}x!FHE.eint<{p}>>, tensor<{2**p}xi64>) -> (tensor<{n_ct}x!FHE.eint<{p}>>)")
print(f" return %{args.n_lut}: tensor<{n_ct}x!FHE.eint<{p}>>")
print(" }")
random_input = np.random.randint(max_value+1, size=n_ct)
print("tests:")
print(" - inputs:")
print(
" - tensor: [{0}]".format(','.join(map(str, random_input))))
print(" shape: [{0}]".format(n_ct))
print(f" - tensor: [{','.join(map(str, random_input))}]")
print(f" shape: [{n_ct}]")
outputs = random_input
for i in range(0, args.n_lut):
outputs = [random_lut[v] for v in outputs]
print(" outputs:")
print(" - tensor: [{0}]".format(','.join(map(str, outputs))))
print(" shape: [{0}]".format(n_ct))
print(f" - tensor: [{','.join(map(str, outputs))}]")
print(f" shape: [{n_ct}]")
print("---")