chore(frontend-python): add test for runtime compiled function

This commit is contained in:
Andrei Stoian
2025-01-31 14:11:28 +01:00
parent 6bec27b270
commit 64d633f270

View File

@@ -5,11 +5,14 @@ Tests of `TFHERSIntegerType` data type.
import json
import os
import tempfile
from functools import partial
from types import FunctionType
import numpy as np
import pytest
from concrete.fhe import tfhers
from concrete.fhe.compilation.compiler import Compiler
DEFAULT_TFHERS_PARAM = tfhers.CryptoParams(
909,
@@ -198,3 +201,50 @@ def test_load_tfhers_params_file(params_dict):
tfhers.get_type_from_params(fpath, True, 8)
os.unlink(fpath)
@pytest.mark.parametrize("n_args", [1, 3])
def test_compile_runtime_defined_function(n_args):
arg_names = [chr(v) for v in range(ord("a"), ord("a") + n_args)]
proxy_func_arg_string = ", ".join(arg_names)
proxy_func_name = "_proxy"
dtype = tfhers.uint8_2_2(DEFAULT_TFHERS_PARAM)
def function_to_proxy_3(a, b, c):
return a + b + c
def function_to_proxy_1(a):
return a + a
func_to_proxy_str = "function_to_proxy_1" if n_args == 1 else "function_to_proxy_3"
function_to_proxy_code_str = (
f"def {proxy_func_name}({proxy_func_arg_string}): \n"
f" from concrete.fhe import tfhers\n"
f" native_inputs = tuple((tfhers.to_native(v) for v in [{proxy_func_arg_string}]))\n"
f" res = {func_to_proxy_str}(*native_inputs)\n"
f" out = tuple((tfhers.from_native(v, dtype) for v in res)) if isinstance(res, tuple) else tfhers.from_native(res, dtype) \n"
f" return out"
)
function_proxy_code = compile(
function_to_proxy_code_str,
__file__,
mode="exec",
)
function_proxy = FunctionType(function_proxy_code.co_consts[0], locals(), proxy_func_name)
inputs_encryption_status = {v: "encrypted" for v in arg_names}
compiler = Compiler(
function_proxy,
parameter_encryption_statuses=inputs_encryption_status,
)
input_vals = list(range(1, n_args + 1))
inputset = [tuple(tfhers.TFHERSInteger(dtype, v) for v in input_vals)]
compiler.compile(
inputset,
)