"""Test file for user-friendly numpy compilation functions"""

import numpy
import pytest

from concrete.common.debugging import format_operation_graph
from concrete.numpy.np_fhe_compiler import NPFHECompiler


def complicated_topology(x, y):
    """Combine x and y through shared intermediates into four int32 outputs.

    ``x + y`` feeds two branches and ``x_p_3`` is consumed by two outputs, so the traced
    graph contains nodes with several consumers.
    """
    intermediate = x + y
    x_p_1 = intermediate + 1
    x_p_2 = intermediate + 2
    x_p_3 = x_p_1 + x_p_2
    return (
        x_p_3.astype(numpy.int32),
        x_p_2.astype(numpy.int32),
        (x_p_2 + 3).astype(numpy.int32),
        x_p_3.astype(numpy.int32) + 67,
    )


def _normalize_whitespace(text):
    """Collapse every run of whitespace in text to one space and strip both ends."""
    return " ".join(text.split())


def _assert_op_graph_text(op_graph, expected):
    """Assert that op_graph prints as expected, ignoring whitespace differences.

    Only whitespace is ignored (line breaks, padding, alignment); every other character,
    including the ``# type`` annotations, must match format_operation_graph exactly.
    On failure the actual printed graph is shown as the assertion message.
    """
    got = format_operation_graph(op_graph)
    assert _normalize_whitespace(got) == _normalize_whitespace(expected), got


def _check_fresh_compiler(compiler, compilation_configuration):
    """Run the checks shared by a freshly constructed compiler, before any evaluation.

    Covers patching the OPGraph input before anything was traced, and checks that the
    compiler holds an equal copy of the configuration rather than the object itself.
    """
    # For coverage when the OPGraph is not yet traced
    compiler._patch_op_graph_input_to_accept_any_integer_input()  # pylint: disable=protected-access

    assert compiler.compilation_configuration == compilation_configuration
    assert compiler.compilation_configuration is not compilation_configuration


def _check_inputset_flush(compiler):
    """Check that querying op_graph flushes the pending inputset, then eval an empty one."""
    # For coverage, check that we flush the inputset when we query the OPGraph
    current_op_graph = compiler.op_graph
    assert current_op_graph is not compiler.op_graph
    assert len(compiler._current_inputset) == 0  # pylint: disable=protected-access

    # For coverage, cover case where the current inputset is empty
    compiler._eval_on_current_inputset()  # pylint: disable=protected-access


def _as_input(input_shape, value):
    """Return value broadcast to input_shape with dtype int64.

    Deliberately ``ones * value`` rather than ``numpy.full``: for ``input_shape == ()``
    numpy arithmetic on a 0-d array returns a numpy scalar, which is what the scalar
    test case feeds the compiler.
    """
    return numpy.ones(input_shape, dtype=numpy.int64) * value


@pytest.mark.parametrize("input_shape", [(), (3, 1, 2)])
def test_np_fhe_compiler_op_graph(
    input_shape, default_compilation_configuration, check_array_equality
):
    """Test NPFHECompiler in two subtests."""
    subtest_np_fhe_compiler_1_input_op_graph(
        input_shape, default_compilation_configuration, check_array_equality
    )
    subtest_np_fhe_compiler_2_inputs_op_graph(
        input_shape, default_compilation_configuration, check_array_equality
    )


def subtest_np_fhe_compiler_1_input_op_graph(
    input_shape, default_compilation_configuration, check_array_equality
):
    """Test NPFHECompiler on a one-input function (y fixed to the constant 0)."""

    def function_to_compile(x):
        return complicated_topology(x, 0)

    compiler = NPFHECompiler(
        function_to_compile,
        {"x": "encrypted"},
        default_compilation_configuration,
    )
    _check_fresh_compiler(compiler, default_compilation_configuration)

    for value in numpy.arange(5):
        x = _as_input(input_shape, value)
        check_array_equality(compiler(x), function_to_compile(x))

    _check_inputset_flush(compiler)

    # Continue a bit more
    for value in numpy.arange(5, 10):
        x = _as_input(input_shape, value)
        check_array_equality(compiler(x), function_to_compile(x))

    # NOTE(review): confirm these annotations still match the current printer output.
    if input_shape == ():
        expected = """
 %0 = 67                              # ClearScalar
 %1 = 2                               # ClearScalar
 %2 = 3                               # ClearScalar
 %3 = 1                               # ClearScalar
 %4 = x                               # EncryptedScalar
 %5 = 0                               # ClearScalar
 %6 = add(%4, %5)                     # EncryptedScalar
 %7 = add(%6, %1)                     # EncryptedScalar
 %8 = add(%6, %3)                     # EncryptedScalar
 %9 = astype(%7, dtype=int32)         # EncryptedScalar
%10 = add(%7, %2)                     # EncryptedScalar
%11 = add(%8, %7)                     # EncryptedScalar
%12 = astype(%10, dtype=int32)        # EncryptedScalar
%13 = astype(%11, dtype=int32)        # EncryptedScalar
%14 = astype(%11, dtype=int32)        # EncryptedScalar
%15 = add(%14, %0)                    # EncryptedScalar
(%13, %9, %12, %15)"""
    else:
        expected = """
 %0 = 67                              # ClearScalar
 %1 = 2                               # ClearScalar
 %2 = 3                               # ClearScalar
 %3 = 1                               # ClearScalar
 %4 = x                               # EncryptedTensor
 %5 = 0                               # ClearScalar
 %6 = add(%4, %5)                     # EncryptedTensor
 %7 = add(%6, %1)                     # EncryptedTensor
 %8 = add(%6, %3)                     # EncryptedTensor
 %9 = astype(%7, dtype=int32)         # EncryptedTensor
%10 = add(%7, %2)                     # EncryptedTensor
%11 = add(%8, %7)                     # EncryptedTensor
%12 = astype(%10, dtype=int32)        # EncryptedTensor
%13 = astype(%11, dtype=int32)        # EncryptedTensor
%14 = astype(%11, dtype=int32)        # EncryptedTensor
%15 = add(%14, %0)                    # EncryptedTensor
(%13, %9, %12, %15)"""

    _assert_op_graph_text(compiler.op_graph, expected)


def subtest_np_fhe_compiler_2_inputs_op_graph(
    input_shape, default_compilation_configuration, check_array_equality
):
    """Test NPFHECompiler on a two-input function (x encrypted, y clear)."""
    compiler = NPFHECompiler(
        complicated_topology,
        {"x": "encrypted", "y": "clear"},
        default_compilation_configuration,
    )
    _check_fresh_compiler(compiler, default_compilation_configuration)

    for x_value, y_value in zip(numpy.arange(5), numpy.arange(5, 10)):
        x = _as_input(input_shape, x_value)
        y = _as_input(input_shape, y_value)
        check_array_equality(compiler(x, y), complicated_topology(x, y))

    _check_inputset_flush(compiler)

    # Continue a bit more
    for x_value, y_value in zip(numpy.arange(5, 10), numpy.arange(5)):
        x = _as_input(input_shape, x_value)
        y = _as_input(input_shape, y_value)
        check_array_equality(compiler(x, y), complicated_topology(x, y))

    # NOTE(review): confirm these annotations still match the current printer output.
    if input_shape == ():
        expected = """
 %0 = 67                              # ClearScalar
 %1 = 2                               # ClearScalar
 %2 = 3                               # ClearScalar
 %3 = 1                               # ClearScalar
 %4 = x                               # EncryptedScalar
 %5 = y                               # ClearScalar
 %6 = add(%4, %5)                     # EncryptedScalar
 %7 = add(%6, %1)                     # EncryptedScalar
 %8 = add(%6, %3)                     # EncryptedScalar
 %9 = astype(%7, dtype=int32)         # EncryptedScalar
%10 = add(%7, %2)                     # EncryptedScalar
%11 = add(%8, %7)                     # EncryptedScalar
%12 = astype(%10, dtype=int32)        # EncryptedScalar
%13 = astype(%11, dtype=int32)        # EncryptedScalar
%14 = astype(%11, dtype=int32)        # EncryptedScalar
%15 = add(%14, %0)                    # EncryptedScalar
(%13, %9, %12, %15)"""
    else:
        expected = """
 %0 = 67                              # ClearScalar
 %1 = 2                               # ClearScalar
 %2 = 3                               # ClearScalar
 %3 = 1                               # ClearScalar
 %4 = x                               # EncryptedTensor
 %5 = y                               # ClearTensor
 %6 = add(%4, %5)                     # EncryptedTensor
 %7 = add(%6, %1)                     # EncryptedTensor
 %8 = add(%6, %3)                     # EncryptedTensor
 %9 = astype(%7, dtype=int32)         # EncryptedTensor
%10 = add(%7, %2)                     # EncryptedTensor
%11 = add(%8, %7)                     # EncryptedTensor
%12 = astype(%10, dtype=int32)        # EncryptedTensor
%13 = astype(%11, dtype=int32)        # EncryptedTensor
%14 = astype(%11, dtype=int32)        # EncryptedTensor
%15 = add(%14, %0)                    # EncryptedTensor
(%13, %9, %12, %15)"""

    _assert_op_graph_text(compiler.op_graph, expected)


def remaining_inputset_size(inputset_len):
    """Return how many samples should still be pending after feeding inputset_len samples.

    Mirrors the expected auto-flush cadence of
    NPFHECompiler.INPUTSET_SIZE_BEFORE_AUTO_BOUND_UPDATE samples.
    """
    return inputset_len % NPFHECompiler.INPUTSET_SIZE_BEFORE_AUTO_BOUND_UPDATE


@pytest.mark.parametrize(
    "inputset_len, expected_remaining_inputset_len",
    [
        (42, remaining_inputset_size(42)),
        (128, remaining_inputset_size(128)),
        (234, remaining_inputset_size(234)),
    ],
)
def test_np_fhe_compiler_auto_flush(
    inputset_len,
    expected_remaining_inputset_len,
    default_compilation_configuration,
    check_array_equality,
):
    """Test that NPFHECompiler auto-flushes its inputset.

    The flush happens every NPFHECompiler.INPUTSET_SIZE_BEFORE_AUTO_BOUND_UPDATE samples.
    """

    def function_to_compile(x):
        return x // 2

    compiler = NPFHECompiler(
        function_to_compile,
        {"x": "encrypted"},
        default_compilation_configuration,
    )

    for value in numpy.arange(inputset_len):
        check_array_equality(compiler(value), function_to_compile(value))

    # Check the inputset was properly flushed
    assert (
        len(compiler._current_inputset)  # pylint: disable=protected-access
        == expected_remaining_inputset_len
    )


def test_np_fhe_compiler_full_compilation(default_compilation_configuration, check_array_equality):
    """Test the case where we generate an FHE circuit."""

    def function_to_compile(x):
        return x + 42

    compiler = NPFHECompiler(
        function_to_compile,
        {"x": "encrypted"},
        default_compilation_configuration,
    )

    # For coverage: requesting a circuit before any evaluation must raise a helpful error
    with pytest.raises(RuntimeError) as excinfo:
        compiler.get_compiled_fhe_circuit()

    assert str(excinfo.value) == (
        "Requested FHECircuit but no OPGraph was compiled. "
        "Did you forget to evaluate NPFHECompiler over an inputset?"
    )

    n_samples = 64

    # numpy integers while building the inputset, plain Python ints when running the circuit
    for value in numpy.arange(n_samples):
        check_array_equality(compiler(value), function_to_compile(value))

    fhe_circuit = compiler.get_compiled_fhe_circuit()

    for value in range(n_samples):
        assert fhe_circuit.run(value) == function_to_compile(value)