"""Test file for hnumpy tracing""" import networkx as nx import numpy import pytest from hdk.common.data_types.floats import Float from hdk.common.data_types.integers import Integer from hdk.common.data_types.values import ClearValue, EncryptedValue from hdk.common.representation import intermediate as ir from hdk.hnumpy import tracing @pytest.mark.parametrize( "operation", [ir.Add, ir.Sub, ir.Mul], ) @pytest.mark.parametrize( "x", [ pytest.param(EncryptedValue(Integer(64, is_signed=False)), id="x: Encrypted uint"), pytest.param( EncryptedValue(Integer(64, is_signed=True)), id="x: Encrypted int", ), pytest.param( ClearValue(Integer(64, is_signed=False)), id="x: Clear uint", ), pytest.param( ClearValue(Integer(64, is_signed=True)), id="x: Clear int", ), ], ) @pytest.mark.parametrize( "y", [ pytest.param(EncryptedValue(Integer(64, is_signed=False)), id="y: Encrypted uint"), pytest.param( EncryptedValue(Integer(64, is_signed=True)), id="y: Encrypted int", ), pytest.param( ClearValue(Integer(64, is_signed=False)), id="y: Clear uint", ), pytest.param( ClearValue(Integer(64, is_signed=True)), id="y: Clear int", ), ], ) def test_hnumpy_tracing_binary_op(operation, x, y, test_helpers): "Test hnumpy tracing a binary operation (in the supported ops)" # Remark that the functions here have a common structure (which is # 2x op y), such that creating further the ref_graph is easy, by # hand def simple_add_function(x, y): z = x + x return z + y def simple_sub_function(x, y): z = x + x return z - y def simple_mul_function(x, y): z = x + x return z * y if operation == ir.Add: function_to_compile = simple_add_function elif operation == ir.Sub: function_to_compile = simple_sub_function elif operation == ir.Mul: function_to_compile = simple_mul_function else: assert False, f"unknown operation {operation}" op_graph = tracing.trace_numpy_function(function_to_compile, {"x": x, "y": y}) ref_graph = nx.MultiDiGraph() input_x = ir.Input(x, input_name="x", program_input_idx=0) input_y = ir.Input(y, input_name="y", program_input_idx=1) add_node_z = ir.Add( ( input_x.outputs[0], input_x.outputs[0], ) ) returned_final_node = operation( ( add_node_z.outputs[0], input_y.outputs[0], ) ) ref_graph.add_node(input_x, content=input_x) ref_graph.add_node(input_y, content=input_y) ref_graph.add_node(add_node_z, content=add_node_z) ref_graph.add_node(returned_final_node, content=returned_final_node) ref_graph.add_edge(input_x, add_node_z, input_idx=0) ref_graph.add_edge(input_x, add_node_z, input_idx=1) ref_graph.add_edge(add_node_z, returned_final_node, input_idx=0) ref_graph.add_edge(input_y, returned_final_node, input_idx=1) assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph) @pytest.mark.parametrize( "function_to_trace,op_graph_expected_output_type,input_and_expected_output_tuples", [ ( lambda x: x.astype(numpy.int32), Integer(32, is_signed=True), [ (14, numpy.int32(14)), (1.5, numpy.int32(1)), (2.0, numpy.int32(2)), (-1.5, numpy.int32(-1)), (2 ** 31 - 1, numpy.int32(2 ** 31 - 1)), (-(2 ** 31), numpy.int32(-(2 ** 31))), ], ), ( lambda x: x.astype(numpy.uint32), Integer(32, is_signed=False), [ (14, numpy.uint32(14)), (1.5, numpy.uint32(1)), (2.0, numpy.uint32(2)), (2 ** 32 - 1, numpy.uint32(2 ** 32 - 1)), ], ), ( lambda x: x.astype(numpy.int64), Integer(64, is_signed=True), [ (14, numpy.int64(14)), (1.5, numpy.int64(1)), (2.0, numpy.int64(2)), (-1.5, numpy.int64(-1)), (2 ** 63 - 1, numpy.int64(2 ** 63 - 1)), (-(2 ** 63), numpy.int64(-(2 ** 63))), ], ), ( lambda x: x.astype(numpy.uint64), Integer(64, is_signed=False), [ (14, numpy.uint64(14)), (1.5, numpy.uint64(1)), (2.0, numpy.uint64(2)), (2 ** 64 - 1, numpy.uint64(2 ** 64 - 1)), ], ), ( lambda x: x.astype(numpy.float64), Float(64), [ (14, numpy.float64(14.0)), (1.5, numpy.float64(1.5)), (2.0, numpy.float64(2.0)), (-1.5, numpy.float64(-1.5)), ], ), ( lambda x: x.astype(numpy.float32), Float(32), [ (14, numpy.float32(14.0)), (1.5, numpy.float32(1.5)), (2.0, numpy.float32(2.0)), (-1.5, numpy.float32(-1.5)), ], ), ], ) def test_tracing_astype( function_to_trace, op_graph_expected_output_type, input_and_expected_output_tuples ): """Test function for NPTracer.astype""" for input_, expected_output in input_and_expected_output_tuples: input_value = ( EncryptedValue(Integer(64, is_signed=True)) if isinstance(input_, int) else EncryptedValue(Float(64)) ) op_graph = tracing.trace_numpy_function(function_to_trace, {"x": input_value}) output_node = op_graph.output_nodes[0] assert op_graph_expected_output_type == output_node.outputs[0].data_type node_results = op_graph.evaluate({0: numpy.array(input_)}) evaluated_output = node_results[output_node] assert isinstance(evaluated_output, type(expected_output)) assert expected_output == evaluated_output