From c51c4bd17a8b4ee91ad6ce471a288879e04e5641 Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Mon, 9 Aug 2021 17:28:30 +0200 Subject: [PATCH] feat(tracing-astype): add astype method on NPTracer --- hdk/hnumpy/tracing.py | 30 ++++++++++++ tests/hnumpy/test_tracing.py | 92 ++++++++++++++++++++++++++++++++++++ 2 files changed, 122 insertions(+) diff --git a/hdk/hnumpy/tracing.py b/hdk/hnumpy/tracing.py index 8072fdc18..0053e1543 100644 --- a/hdk/hnumpy/tracing.py +++ b/hdk/hnumpy/tracing.py @@ -1,14 +1,44 @@ """hnumpy tracing utilities""" from typing import Callable, Dict +import numpy +from numpy.typing import DTypeLike + from ..common.data_types import BaseValue from ..common.operator_graph import OPGraph +from ..common.representation import intermediate as ir from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters +from .np_dtypes_helpers import convert_numpy_dtype_to_common_dtype class NPTracer(BaseTracer): """Tracer class for numpy operations""" + def astype(self, numpy_dtype: DTypeLike, *args, **kwargs) -> "NPTracer": + """Support numpy astype feature, for now it only accepts a dtype and no additional + parameters, *args and **kwargs are accepted for interface compatibility only + + Args: + numpy_dtype (DTypeLike): The object describing a numpy type + + Returns: + NPTracer: The NPTracer representing the casting operation + """ + assert len(args) == 0, f"astype currently only supports tracing without *args, got {args}" + assert ( + len(kwargs) == 0 + ), f"astype currently only supports tracing without **kwargs, got {kwargs}" + + normalized_numpy_dtype = numpy.dtype(numpy_dtype) + output_dtype = convert_numpy_dtype_to_common_dtype(numpy_dtype) + traced_computation = ir.ArbitraryFunction( + input_base_value=self.output, + arbitrary_func=normalized_numpy_dtype.type, + output_dtype=output_dtype, + ) + output_tracer = NPTracer([self], traced_computation=traced_computation, output_index=0) + return output_tracer + def trace_numpy_function( function_to_trace: Callable, function_parameters: Dict[str, BaseValue] diff --git a/tests/hnumpy/test_tracing.py b/tests/hnumpy/test_tracing.py index 4dfada2d6..2634560d2 100644 --- a/tests/hnumpy/test_tracing.py +++ b/tests/hnumpy/test_tracing.py @@ -1,8 +1,10 @@ """Test file for hnumpy tracing""" import networkx as nx +import numpy import pytest +from hdk.common.data_types.floats import Float from hdk.common.data_types.integers import Integer from hdk.common.data_types.values import ClearValue, EncryptedValue from hdk.common.representation import intermediate as ir @@ -109,3 +111,93 @@ def test_hnumpy_tracing_binary_op(operation, x, y, test_helpers): ref_graph.add_edge(input_y, returned_final_node, input_idx=1) assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph) + + +@pytest.mark.parametrize( + "function_to_trace,op_graph_expected_output_type,input_and_expected_output_tuples", + [ + ( + lambda x: x.astype(numpy.int32), + Integer(32, is_signed=True), + [ + (14, numpy.int32(14)), + (1.5, numpy.int32(1)), + (2.0, numpy.int32(2)), + (-1.5, numpy.int32(-1)), + (2 ** 31 - 1, numpy.int32(2 ** 31 - 1)), + (-(2 ** 31), numpy.int32(-(2 ** 31))), + ], + ), + ( + lambda x: x.astype(numpy.uint32), + Integer(32, is_signed=False), + [ + (14, numpy.uint32(14)), + (1.5, numpy.uint32(1)), + (2.0, numpy.uint32(2)), + (2 ** 32 - 1, numpy.uint32(2 ** 32 - 1)), + ], + ), + ( + lambda x: x.astype(numpy.int64), + Integer(64, is_signed=True), + [ + (14, numpy.int64(14)), + (1.5, numpy.int64(1)), + (2.0, numpy.int64(2)), + (-1.5, numpy.int64(-1)), + (2 ** 63 - 1, numpy.int64(2 ** 63 - 1)), + (-(2 ** 63), numpy.int64(-(2 ** 63))), + ], + ), + ( + lambda x: x.astype(numpy.uint64), + Integer(64, is_signed=False), + [ + (14, numpy.uint64(14)), + (1.5, numpy.uint64(1)), + (2.0, numpy.uint64(2)), + (2 ** 64 - 1, numpy.uint64(2 ** 64 - 1)), + ], + ), + ( + lambda x: x.astype(numpy.float64), + Float(64), + [ + (14, numpy.float64(14.0)), + (1.5, numpy.float64(1.5)), + (2.0, numpy.float64(2.0)), + (-1.5, numpy.float64(-1.5)), + ], + ), + ( + lambda x: x.astype(numpy.float32), + Float(32), + [ + (14, numpy.float32(14.0)), + (1.5, numpy.float32(1.5)), + (2.0, numpy.float32(2.0)), + (-1.5, numpy.float32(-1.5)), + ], + ), + ], +) +def test_tracing_astype( + function_to_trace, op_graph_expected_output_type, input_and_expected_output_tuples +): + """Test function for NPTracer.astype""" + for input_, expected_output in input_and_expected_output_tuples: + input_value = ( + EncryptedValue(Integer(64, is_signed=True)) + if isinstance(input_, int) + else EncryptedValue(Float(64)) + ) + + op_graph = tracing.trace_numpy_function(function_to_trace, {"x": input_value}) + output_node = op_graph.output_nodes[0] + assert op_graph_expected_output_type == output_node.outputs[0].data_type + + node_results = op_graph.evaluate({0: numpy.array(input_)}) + evaluated_output = node_results[output_node] + assert isinstance(evaluated_output, type(expected_output)) + assert expected_output == evaluated_output