From 64d633f2700925c9e7c8cf62ba139e16d6bb05a3 Mon Sep 17 00:00:00 2001 From: Andrei Stoian Date: Fri, 31 Jan 2025 14:11:28 +0100 Subject: [PATCH] chore(frontend-python): add test for runtime compiled function --- .../tests/dtypes/test_tfhers.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/frontends/concrete-python/tests/dtypes/test_tfhers.py b/frontends/concrete-python/tests/dtypes/test_tfhers.py index 2c5bc728a..c44d51e28 100644 --- a/frontends/concrete-python/tests/dtypes/test_tfhers.py +++ b/frontends/concrete-python/tests/dtypes/test_tfhers.py @@ -5,11 +5,14 @@ Tests of `TFHERSIntegerType` data type. import json import os import tempfile +from functools import partial +from types import FunctionType import numpy as np import pytest from concrete.fhe import tfhers +from concrete.fhe.compilation.compiler import Compiler DEFAULT_TFHERS_PARAM = tfhers.CryptoParams( 909, @@ -198,3 +201,50 @@ def test_load_tfhers_params_file(params_dict): tfhers.get_type_from_params(fpath, True, 8) os.unlink(fpath) + + +@pytest.mark.parametrize("n_args", [1, 3]) +def test_compile_runtime_defined_function(n_args): + arg_names = [chr(v) for v in range(ord("a"), ord("a") + n_args)] + proxy_func_arg_string = ", ".join(arg_names) + proxy_func_name = "_proxy" + + dtype = tfhers.uint8_2_2(DEFAULT_TFHERS_PARAM) + + def function_to_proxy_3(a, b, c): + return a + b + c + + def function_to_proxy_1(a): + return a + a + + func_to_proxy_str = "function_to_proxy_1" if n_args == 1 else "function_to_proxy_3" + + function_to_proxy_code_str = ( + f"def {proxy_func_name}({proxy_func_arg_string}): \n" + f" from concrete.fhe import tfhers\n" + f" native_inputs = tuple((tfhers.to_native(v) for v in [{proxy_func_arg_string}]))\n" + f" res = {func_to_proxy_str}(*native_inputs)\n" + f" out = tuple((tfhers.from_native(v, dtype) for v in res)) if isinstance(res, tuple) else tfhers.from_native(res, dtype) \n" + f" return out" + ) + + function_proxy_code = compile( + function_to_proxy_code_str, + __file__, + mode="exec", + ) + function_proxy = FunctionType(function_proxy_code.co_consts[0], locals(), proxy_func_name) + + inputs_encryption_status = {v: "encrypted" for v in arg_names} + + compiler = Compiler( + function_proxy, + parameter_encryption_statuses=inputs_encryption_status, + ) + + input_vals = list(range(1, n_args + 1)) + inputset = [tuple(tfhers.TFHERSInteger(dtype, v) for v in input_vals)] + + compiler.compile( + inputset, + )