mirror of
https://github.com/zama-ai/concrete.git
synced 2026-01-09 12:57:55 -05:00
chore(frontend-python): add test for runtime compiled function
This commit is contained in:
@@ -5,11 +5,14 @@ Tests of `TFHERSIntegerType` data type.
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from functools import partial
|
||||
from types import FunctionType
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from concrete.fhe import tfhers
|
||||
from concrete.fhe.compilation.compiler import Compiler
|
||||
|
||||
DEFAULT_TFHERS_PARAM = tfhers.CryptoParams(
|
||||
909,
|
||||
@@ -198,3 +201,50 @@ def test_load_tfhers_params_file(params_dict):
|
||||
|
||||
tfhers.get_type_from_params(fpath, True, 8)
|
||||
os.unlink(fpath)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_args", [1, 3])
|
||||
def test_compile_runtime_defined_function(n_args):
|
||||
arg_names = [chr(v) for v in range(ord("a"), ord("a") + n_args)]
|
||||
proxy_func_arg_string = ", ".join(arg_names)
|
||||
proxy_func_name = "_proxy"
|
||||
|
||||
dtype = tfhers.uint8_2_2(DEFAULT_TFHERS_PARAM)
|
||||
|
||||
def function_to_proxy_3(a, b, c):
|
||||
return a + b + c
|
||||
|
||||
def function_to_proxy_1(a):
|
||||
return a + a
|
||||
|
||||
func_to_proxy_str = "function_to_proxy_1" if n_args == 1 else "function_to_proxy_3"
|
||||
|
||||
function_to_proxy_code_str = (
|
||||
f"def {proxy_func_name}({proxy_func_arg_string}): \n"
|
||||
f" from concrete.fhe import tfhers\n"
|
||||
f" native_inputs = tuple((tfhers.to_native(v) for v in [{proxy_func_arg_string}]))\n"
|
||||
f" res = {func_to_proxy_str}(*native_inputs)\n"
|
||||
f" out = tuple((tfhers.from_native(v, dtype) for v in res)) if isinstance(res, tuple) else tfhers.from_native(res, dtype) \n"
|
||||
f" return out"
|
||||
)
|
||||
|
||||
function_proxy_code = compile(
|
||||
function_to_proxy_code_str,
|
||||
__file__,
|
||||
mode="exec",
|
||||
)
|
||||
function_proxy = FunctionType(function_proxy_code.co_consts[0], locals(), proxy_func_name)
|
||||
|
||||
inputs_encryption_status = {v: "encrypted" for v in arg_names}
|
||||
|
||||
compiler = Compiler(
|
||||
function_proxy,
|
||||
parameter_encryption_statuses=inputs_encryption_status,
|
||||
)
|
||||
|
||||
input_vals = list(range(1, n_args + 1))
|
||||
inputset = [tuple(tfhers.TFHERSInteger(dtype, v) for v in input_vals)]
|
||||
|
||||
compiler.compile(
|
||||
inputset,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user